diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 32934d4e970..f67bfb98641 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -7,7 +7,7 @@ from homeassistant import config_entries, data_entry_flow from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT from homeassistant.components import websocket_api from homeassistant.components.http import HomeAssistantView -from homeassistant.const import HTTP_NOT_FOUND +from homeassistant.const import HTTP_FORBIDDEN, HTTP_NOT_FOUND from homeassistant.core import callback from homeassistant.exceptions import Unauthorized import homeassistant.helpers.config_validation as cv @@ -22,6 +22,7 @@ async def async_setup(hass): """Enable the Home Assistant views.""" hass.http.register_view(ConfigManagerEntryIndexView) hass.http.register_view(ConfigManagerEntryResourceView) + hass.http.register_view(ConfigManagerEntryResourceReloadView) hass.http.register_view(ConfigManagerFlowIndexView(hass.config_entries.flow)) hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow)) hass.http.register_view(ConfigManagerAvailableFlowView) @@ -92,6 +93,29 @@ class ConfigManagerEntryResourceView(HomeAssistantView): return self.json(result) +class ConfigManagerEntryResourceReloadView(HomeAssistantView): + """View to reload a config entry.""" + + url = "/api/config/config_entries/entry/{entry_id}/reload" + name = "api:config:config_entries:entry:resource:reload" + + async def post(self, request, entry_id): + """Reload a config entry.""" + if not request["hass_user"].is_admin: + raise Unauthorized(config_entry_id=entry_id, permission="remove") + + hass = request.app["hass"] + + try: + result = await hass.config_entries.async_reload(entry_id) + except config_entries.OperationNotAllowed: + return self.json_message("Entry cannot be reloaded", HTTP_FORBIDDEN) + except config_entries.UnknownEntry: + return self.json_message("Invalid entry specified", HTTP_NOT_FOUND) + + return self.json({"require_restart": not result}) + + class ConfigManagerFlowIndexView(FlowManagerIndexView): """View to create config flows.""" @@ -345,4 +369,5 @@ def entry_json(entry: config_entries.ConfigEntry) -> dict: "state": entry.state, "connection_class": entry.connection_class, "supports_options": supports_options, + "supports_unload": entry.supports_unload, } diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 04bdbf236a5..01eb63fe05f 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -110,6 +110,7 @@ class ConfigEntry: "data", "options", "unique_id", + "supports_unload", "system_options", "source", "connection_class", @@ -167,6 +168,9 @@ class ConfigEntry: # Unique ID of this entry. self.unique_id = unique_id + # Supports unload + self.supports_unload = False + # Listeners to call on update self.update_listeners: List[weakref.ReferenceType[UpdateListenerType]] = [] @@ -187,6 +191,8 @@ class ConfigEntry: if integration is None: integration = await loader.async_get_integration(hass, self.domain) + self.supports_unload = await support_entry_unload(hass, self.domain) + try: component = integration.get_component() except ImportError as err: @@ -1116,9 +1122,7 @@ class EntityRegistryDisabledHandler: ) assert config_entry is not None - if config_entry.entry_id not in self.changed and await support_entry_unload( - self.hass, config_entry.domain - ): + if config_entry.entry_id not in self.changed and config_entry.supports_unload: self.changed.add(config_entry.entry_id) if not self.changed: diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index e5da27818fc..9bd8875add0 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -53,12 +53,14 @@ async def test_get_entries(hass, client): "comp2", "Comp 2", lambda: None, core_ce.CONN_CLASS_ASSUMED ) - MockConfigEntry( + entry = MockConfigEntry( domain="comp1", title="Test 1", source="bla", connection_class=core_ce.CONN_CLASS_LOCAL_POLL, - ).add_to_hass(hass) + ) + entry.supports_unload = True + entry.add_to_hass(hass) MockConfigEntry( domain="comp2", title="Test 2", @@ -80,6 +82,7 @@ async def test_get_entries(hass, client): "state": "not_loaded", "connection_class": "local_poll", "supports_options": True, + "supports_unload": True, }, { "domain": "comp2", @@ -88,6 +91,7 @@ async def test_get_entries(hass, client): "state": "loaded", "connection_class": "assumed", "supports_options": False, + "supports_unload": False, }, ] @@ -103,6 +107,25 @@ async def test_remove_entry(hass, client): assert len(hass.config_entries.async_entries()) == 0 +async def test_reload_entry(hass, client): + """Test reloading an entry via the API.""" + entry = MockConfigEntry(domain="demo", state=core_ce.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + resp = await client.post( + f"/api/config/config_entries/entry/{entry.entry_id}/reload" + ) + assert resp.status == 200 + data = await resp.json() + assert data == {"require_restart": True} + assert len(hass.config_entries.async_entries()) == 1 + + +async def test_reload_invalid_entry(hass, client): + """Test reloading an invalid entry via the API.""" + resp = await client.post("/api/config/config_entries/entry/invalid/reload") + assert resp.status == 404 + + async def test_remove_entry_unauth(hass, client, hass_admin_user): """Test removing an entry via the API.""" hass_admin_user.groups = [] @@ -113,6 +136,29 @@ async def test_remove_entry_unauth(hass, client, hass_admin_user): assert len(hass.config_entries.async_entries()) == 1 +async def test_reload_entry_unauth(hass, client, hass_admin_user): + """Test reloading an entry via the API.""" + hass_admin_user.groups = [] + entry = MockConfigEntry(domain="demo", state=core_ce.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + resp = await client.post( + f"/api/config/config_entries/entry/{entry.entry_id}/reload" + ) + assert resp.status == 401 + assert len(hass.config_entries.async_entries()) == 1 + + +async def test_reload_entry_in_failed_state(hass, client, hass_admin_user): + """Test reloading an entry via the API that has already failed to unload.""" + entry = MockConfigEntry(domain="demo", state=core_ce.ENTRY_STATE_FAILED_UNLOAD) + entry.add_to_hass(hass) + resp = await client.post( + f"/api/config/config_entries/entry/{entry.entry_id}/reload" + ) + assert resp.status == 403 + assert len(hass.config_entries.async_entries()) == 1 + + async def test_available_flows(hass, client): """Test querying the available flows.""" with patch.object(config_flows, "FLOWS", ["hello", "world"]): diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 988e67718ef..816530befea 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -53,6 +53,7 @@ async def test_call_setup_entry(hass): """Test we call .setup_entry.""" entry = MockConfigEntry(domain="comp") entry.add_to_hass(hass) + assert not entry.supports_unload mock_setup_entry = AsyncMock(return_value=True) mock_migrate_entry = AsyncMock(return_value=True) @@ -67,16 +68,49 @@ async def test_call_setup_entry(hass): ) mock_entity_platform(hass, "config_flow.comp", None) - result = await async_setup_component(hass, "comp", {}) + with patch("homeassistant.config_entries.support_entry_unload", return_value=True): + result = await async_setup_component(hass, "comp", {}) + await hass.async_block_till_done() assert result assert len(mock_migrate_entry.mock_calls) == 0 assert len(mock_setup_entry.mock_calls) == 1 assert entry.state == config_entries.ENTRY_STATE_LOADED + assert entry.supports_unload + + +async def test_call_setup_entry_without_reload_support(hass): + """Test we call .setup_entry and the does not support unloading.""" + entry = MockConfigEntry(domain="comp") + entry.add_to_hass(hass) + assert not entry.supports_unload + + mock_setup_entry = AsyncMock(return_value=True) + mock_migrate_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup_entry=mock_setup_entry, + async_migrate_entry=mock_migrate_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + with patch("homeassistant.config_entries.support_entry_unload", return_value=False): + result = await async_setup_component(hass, "comp", {}) + await hass.async_block_till_done() + assert result + assert len(mock_migrate_entry.mock_calls) == 0 + assert len(mock_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + assert not entry.supports_unload async def test_call_async_migrate_entry(hass): """Test we call .async_migrate_entry when version mismatch.""" entry = MockConfigEntry(domain="comp") + assert not entry.supports_unload entry.version = 2 entry.add_to_hass(hass) @@ -93,11 +127,14 @@ async def test_call_async_migrate_entry(hass): ) mock_entity_platform(hass, "config_flow.comp", None) - result = await async_setup_component(hass, "comp", {}) + with patch("homeassistant.config_entries.support_entry_unload", return_value=True): + result = await async_setup_component(hass, "comp", {}) + await hass.async_block_till_done() assert result assert len(mock_migrate_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 assert entry.state == config_entries.ENTRY_STATE_LOADED + assert entry.supports_unload async def test_call_async_migrate_entry_failure_false(hass): @@ -105,6 +142,7 @@ async def test_call_async_migrate_entry_failure_false(hass): entry = MockConfigEntry(domain="comp") entry.version = 2 entry.add_to_hass(hass) + assert not entry.supports_unload mock_migrate_entry = AsyncMock(return_value=False) mock_setup_entry = AsyncMock(return_value=True) @@ -124,6 +162,7 @@ async def test_call_async_migrate_entry_failure_false(hass): assert len(mock_migrate_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 0 assert entry.state == config_entries.ENTRY_STATE_MIGRATION_ERROR + assert not entry.supports_unload async def test_call_async_migrate_entry_failure_exception(hass): @@ -131,6 +170,7 @@ async def test_call_async_migrate_entry_failure_exception(hass): entry = MockConfigEntry(domain="comp") entry.version = 2 entry.add_to_hass(hass) + assert not entry.supports_unload mock_migrate_entry = AsyncMock(side_effect=Exception) mock_setup_entry = AsyncMock(return_value=True) @@ -150,6 +190,7 @@ async def test_call_async_migrate_entry_failure_exception(hass): assert len(mock_migrate_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 0 assert entry.state == config_entries.ENTRY_STATE_MIGRATION_ERROR + assert not entry.supports_unload async def test_call_async_migrate_entry_failure_not_bool(hass): @@ -157,6 +198,7 @@ async def test_call_async_migrate_entry_failure_not_bool(hass): entry = MockConfigEntry(domain="comp") entry.version = 2 entry.add_to_hass(hass) + assert not entry.supports_unload mock_migrate_entry = AsyncMock(return_value=None) mock_setup_entry = AsyncMock(return_value=True) @@ -176,6 +218,7 @@ async def test_call_async_migrate_entry_failure_not_bool(hass): assert len(mock_migrate_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 0 assert entry.state == config_entries.ENTRY_STATE_MIGRATION_ERROR + assert not entry.supports_unload async def test_call_async_migrate_entry_failure_not_supported(hass): @@ -183,6 +226,7 @@ async def test_call_async_migrate_entry_failure_not_supported(hass): entry = MockConfigEntry(domain="comp") entry.version = 2 entry.add_to_hass(hass) + assert not entry.supports_unload mock_setup_entry = AsyncMock(return_value=True) @@ -193,6 +237,7 @@ async def test_call_async_migrate_entry_failure_not_supported(hass): assert result assert len(mock_setup_entry.mock_calls) == 0 assert entry.state == config_entries.ENTRY_STATE_MIGRATION_ERROR + assert not entry.supports_unload async def test_remove_entry(hass, manager): @@ -991,6 +1036,7 @@ async def test_reload_entry_entity_registry_works(hass): config_entry = MockConfigEntry( domain="comp", state=config_entries.ENTRY_STATE_LOADED ) + config_entry.supports_unload = True config_entry.add_to_hass(hass) mock_setup_entry = AsyncMock(return_value=True) mock_unload_entry = AsyncMock(return_value=True)