From e96cea997eedea32a9b01b0de274abebde48a9ce Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 24 Jan 2023 14:13:49 +0100 Subject: [PATCH] Add reboot button to SFRBox (#86514) --- homeassistant/components/sfr_box/__init__.py | 7 +- homeassistant/components/sfr_box/button.py | 108 +++++++++++++++++++ homeassistant/components/sfr_box/const.py | 1 + homeassistant/components/sfr_box/models.py | 2 + tests/components/sfr_box/const.py | 10 ++ tests/components/sfr_box/test_button.py | 71 ++++++++++++ 6 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 homeassistant/components/sfr_box/button.py create mode 100644 tests/components/sfr_box/test_button.py diff --git a/homeassistant/components/sfr_box/__init__.py b/homeassistant/components/sfr_box/__init__.py index 8c7bca7a913..07f122fa4b2 100644 --- a/homeassistant/components/sfr_box/__init__.py +++ b/homeassistant/components/sfr_box/__init__.py @@ -13,7 +13,7 @@ from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import device_registry as dr from homeassistant.helpers.httpx_client import get_async_client -from .const import DOMAIN, PLATFORMS +from .const import DOMAIN, PLATFORMS, PLATFORMS_WITH_AUTH from .coordinator import SFRDataUpdateCoordinator from .models import DomainData @@ -21,6 +21,7 @@ from .models import DomainData async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up SFR box as config entry.""" box = SFRBox(ip=entry.data[CONF_HOST], client=get_async_client(hass)) + platforms = PLATFORMS if (username := entry.data.get(CONF_USERNAME)) and ( password := entry.data.get(CONF_PASSWORD) ): @@ -30,8 +31,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: raise ConfigEntryAuthFailed() from err except SFRBoxError as err: raise ConfigEntryNotReady() from err + platforms = PLATFORMS_WITH_AUTH data = DomainData( + box=box, dsl=SFRDataUpdateCoordinator(hass, box, "dsl", lambda b: b.dsl_get_info()), system=SFRDataUpdateCoordinator( hass, box, "system", lambda b: b.system_get_info() @@ -56,7 +59,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: configuration_url=f"http://{entry.data[CONF_HOST]}", ) - await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + await hass.config_entries.async_forward_entry_setups(entry, platforms) return True diff --git a/homeassistant/components/sfr_box/button.py b/homeassistant/components/sfr_box/button.py new file mode 100644 index 00000000000..a6fa9af5385 --- /dev/null +++ b/homeassistant/components/sfr_box/button.py @@ -0,0 +1,108 @@ +"""SFR Box button platform.""" +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Coroutine +from dataclasses import dataclass +from functools import wraps +from typing import Any, Concatenate, ParamSpec, TypeVar + +from sfrbox_api.bridge import SFRBox +from sfrbox_api.exceptions import SFRBoxError +from sfrbox_api.models import SystemInfo + +from homeassistant.components.button import ( + ButtonDeviceClass, + ButtonEntity, + ButtonEntityDescription, +) +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity import EntityCategory +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .models import DomainData + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def with_error_wrapping( + func: Callable[Concatenate[SFRBoxButton, _P], Awaitable[_T]] +) -> Callable[Concatenate[SFRBoxButton, _P], Coroutine[Any, Any, _T]]: + """Catch SFR errors.""" + + @wraps(func) + async def wrapper( + self: SFRBoxButton, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _T: + """Catch SFRBoxError errors and raise HomeAssistantError.""" + try: + return await func(self, *args, **kwargs) + except SFRBoxError as err: + raise HomeAssistantError(err) from err + + return wrapper + + +@dataclass +class SFRBoxButtonMixin: + """Mixin for SFR Box buttons.""" + + async_press: Callable[[SFRBox], Coroutine[None, None, None]] + + +@dataclass +class SFRBoxButtonEntityDescription(ButtonEntityDescription, SFRBoxButtonMixin): + """Description for SFR Box buttons.""" + + +BUTTON_TYPES: tuple[SFRBoxButtonEntityDescription, ...] = ( + SFRBoxButtonEntityDescription( + async_press=lambda x: x.system_reboot(), + device_class=ButtonDeviceClass.RESTART, + entity_category=EntityCategory.CONFIG, + key="system_reboot", + name="Reboot", + ), +) + + +async def async_setup_entry( + hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback +) -> None: + """Set up the buttons.""" + data: DomainData = hass.data[DOMAIN][entry.entry_id] + + entities = [ + SFRBoxButton(data.box, description, data.system.data) + for description in BUTTON_TYPES + ] + async_add_entities(entities) + + +class SFRBoxButton(ButtonEntity): + """Mixin for button specific attributes.""" + + entity_description: SFRBoxButtonEntityDescription + _attr_has_entity_name = True + + def __init__( + self, + box: SFRBox, + description: SFRBoxButtonEntityDescription, + system_info: SystemInfo, + ) -> None: + """Initialize the sensor.""" + self.entity_description = description + self._box = box + self._attr_unique_id = f"{system_info.mac_addr}_{description.key}" + self._attr_device_info = {"identifiers": {(DOMAIN, system_info.mac_addr)}} + + @with_error_wrapping + async def async_press(self) -> None: + """Process the button press.""" + await self.entity_description.async_press(self._box) diff --git a/homeassistant/components/sfr_box/const.py b/homeassistant/components/sfr_box/const.py index 7a64994ce42..3700890b957 100644 --- a/homeassistant/components/sfr_box/const.py +++ b/homeassistant/components/sfr_box/const.py @@ -7,3 +7,4 @@ DEFAULT_USERNAME = "admin" DOMAIN = "sfr_box" PLATFORMS = [Platform.BINARY_SENSOR, Platform.SENSOR] +PLATFORMS_WITH_AUTH = [*PLATFORMS, Platform.BUTTON] diff --git a/homeassistant/components/sfr_box/models.py b/homeassistant/components/sfr_box/models.py index 242a248309c..e2f86aeb924 100644 --- a/homeassistant/components/sfr_box/models.py +++ b/homeassistant/components/sfr_box/models.py @@ -1,6 +1,7 @@ """SFR Box models.""" from dataclasses import dataclass +from sfrbox_api.bridge import SFRBox from sfrbox_api.models import DslInfo, SystemInfo from .coordinator import SFRDataUpdateCoordinator @@ -10,5 +11,6 @@ from .coordinator import SFRDataUpdateCoordinator class DomainData: """Domain data for SFR Box.""" + box: SFRBox dsl: SFRDataUpdateCoordinator[DslInfo] system: SFRDataUpdateCoordinator[SystemInfo] diff --git a/tests/components/sfr_box/const.py b/tests/components/sfr_box/const.py index 6bd5a1b8a52..8b7513aaf8c 100644 --- a/tests/components/sfr_box/const.py +++ b/tests/components/sfr_box/const.py @@ -1,5 +1,6 @@ """Constants for SFR Box tests.""" from homeassistant.components.binary_sensor import BinarySensorDeviceClass +from homeassistant.components.button import ButtonDeviceClass from homeassistant.components.sensor import ( ATTR_OPTIONS, ATTR_STATE_CLASS, @@ -18,6 +19,7 @@ from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, SIGNAL_STRENGTH_DECIBELS, STATE_ON, + STATE_UNKNOWN, Platform, UnitOfDataRate, UnitOfElectricPotential, @@ -48,6 +50,14 @@ EXPECTED_ENTITIES = { ATTR_UNIQUE_ID: "e4:5d:51:00:11:22_dsl_status", }, ], + Platform.BUTTON: [ + { + ATTR_DEVICE_CLASS: ButtonDeviceClass.RESTART, + ATTR_ENTITY_ID: "button.sfr_box_reboot", + ATTR_STATE: STATE_UNKNOWN, + ATTR_UNIQUE_ID: "e4:5d:51:00:11:22_system_reboot", + }, + ], Platform.SENSOR: [ { ATTR_DEFAULT_DISABLED: True, diff --git a/tests/components/sfr_box/test_button.py b/tests/components/sfr_box/test_button.py new file mode 100644 index 00000000000..9872e39d3c4 --- /dev/null +++ b/tests/components/sfr_box/test_button.py @@ -0,0 +1,71 @@ +"""Test the SFR Box buttons.""" +from collections.abc import Generator +from unittest.mock import patch + +import pytest +from sfrbox_api.exceptions import SFRBoxError + +from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN, SERVICE_PRESS +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import ATTR_ENTITY_ID, Platform +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError + +from . import check_device_registry, check_entities +from .const import EXPECTED_ENTITIES + +from tests.common import mock_device_registry, mock_registry + +pytestmark = pytest.mark.usefixtures("system_get_info", "dsl_get_info") + + +@pytest.fixture(autouse=True) +def override_platforms() -> Generator[None, None, None]: + """Override PLATFORMS_WITH_AUTH.""" + with patch( + "homeassistant.components.sfr_box.PLATFORMS_WITH_AUTH", [Platform.BUTTON] + ), patch("homeassistant.components.sfr_box.coordinator.SFRBox.authenticate"): + yield + + +async def test_buttons( + hass: HomeAssistant, config_entry_with_auth: ConfigEntry +) -> None: + """Test for SFR Box buttons.""" + entity_registry = mock_registry(hass) + device_registry = mock_device_registry(hass) + + await hass.config_entries.async_setup(config_entry_with_auth.entry_id) + await hass.async_block_till_done() + + check_device_registry(device_registry, EXPECTED_ENTITIES["expected_device"]) + + expected_entities = EXPECTED_ENTITIES[Platform.BUTTON] + assert len(entity_registry.entities) == len(expected_entities) + + check_entities(hass, entity_registry, expected_entities) + + # Reboot success + service_data = {ATTR_ENTITY_ID: "button.sfr_box_reboot"} + with patch( + "homeassistant.components.sfr_box.button.SFRBox.system_reboot" + ) as mock_action: + await hass.services.async_call( + BUTTON_DOMAIN, SERVICE_PRESS, service_data=service_data, blocking=True + ) + + assert len(mock_action.mock_calls) == 1 + assert mock_action.mock_calls[0][1] == () + + # Reboot failed + service_data = {ATTR_ENTITY_ID: "button.sfr_box_reboot"} + with patch( + "homeassistant.components.sfr_box.button.SFRBox.system_reboot", + side_effect=SFRBoxError, + ) as mock_action, pytest.raises(HomeAssistantError): + await hass.services.async_call( + BUTTON_DOMAIN, SERVICE_PRESS, service_data=service_data, blocking=True + ) + + assert len(mock_action.mock_calls) == 1 + assert mock_action.mock_calls[0][1] == ()