diff --git a/homeassistant/components/switcher_kis/__init__.py b/homeassistant/components/switcher_kis/__init__.py index 60b3b18b0b0..555ba951041 100644 --- a/homeassistant/components/switcher_kis/__init__.py +++ b/homeassistant/components/switcher_kis/__init__.py @@ -10,7 +10,9 @@ from aioswitcher.device import SwitcherBase from homeassistant.config_entries import ConfigEntry from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers import device_registry as dr +from .const import DOMAIN from .coordinator import SwitcherDataUpdateCoordinator PLATFORMS = [ @@ -77,3 +79,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool: """Unload a config entry.""" return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + + +async def async_remove_config_entry_device( + hass: HomeAssistant, config_entry: SwitcherConfigEntry, device_entry: dr.DeviceEntry +) -> bool: + """Remove a config entry from a device.""" + return not device_entry.identifiers.intersection( + (DOMAIN, device_id) for device_id in config_entry.runtime_data + ) diff --git a/tests/components/switcher_kis/test_init.py b/tests/components/switcher_kis/test_init.py index 14217a7e044..a652348463e 100644 --- a/tests/components/switcher_kis/test_init.py +++ b/tests/components/switcher_kis/test_init.py @@ -4,16 +4,19 @@ from datetime import timedelta import pytest -from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC +from homeassistant.components.switcher_kis.const import DOMAIN, MAX_UPDATE_INTERVAL_SEC from homeassistant.config_entries import ConfigEntryState from homeassistant.const import STATE_UNAVAILABLE from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr +from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util, slugify from . import init_integration -from .consts import DUMMY_SWITCHER_DEVICES +from .consts import DUMMY_DEVICE_ID1, DUMMY_DEVICE_ID4, DUMMY_SWITCHER_DEVICES from tests.common import async_fire_time_changed +from tests.typing import WebSocketGenerator async def test_update_fail( @@ -78,3 +81,56 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None: assert entry.state is ConfigEntryState.NOT_LOADED assert mock_bridge.is_running is False + + +async def test_remove_device( + hass: HomeAssistant, mock_bridge, hass_ws_client: WebSocketGenerator +) -> None: + """Test being able to remove a disconnected device.""" + assert await async_setup_component(hass, "config", {}) + entry = await init_integration(hass) + entry_id = entry.entry_id + assert mock_bridge + + mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES) + await hass.async_block_till_done() + + assert mock_bridge.is_running is True + assert len(entry.runtime_data) == 2 + + device_registry = dr.async_get(hass) + live_device_id = DUMMY_DEVICE_ID1 + dead_device_id = DUMMY_DEVICE_ID4 + + assert len(dr.async_entries_for_config_entry(device_registry, entry_id)) == 2 + + # Create a dead device + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + identifiers={(DOMAIN, dead_device_id)}, + manufacturer="Switcher", + model="Switcher Model", + name="Switcher Device", + ) + await hass.async_block_till_done() + assert len(dr.async_entries_for_config_entry(device_registry, entry_id)) == 3 + + # Try to remove a live device - fails + device = device_registry.async_get_device(identifiers={(DOMAIN, live_device_id)}) + client = await hass_ws_client(hass) + response = await client.remove_device(device.id, entry_id) + assert not response["success"] + assert len(dr.async_entries_for_config_entry(device_registry, entry_id)) == 3 + assert ( + device_registry.async_get_device(identifiers={(DOMAIN, live_device_id)}) + is not None + ) + + # Try to remove a dead device - succeeds + device = device_registry.async_get_device(identifiers={(DOMAIN, dead_device_id)}) + response = await client.remove_device(device.id, entry_id) + assert response["success"] + assert len(dr.async_entries_for_config_entry(device_registry, entry_id)) == 2 + assert ( + device_registry.async_get_device(identifiers={(DOMAIN, dead_device_id)}) is None + )