diff --git a/homeassistant/components/nice_go/config_flow.py b/homeassistant/components/nice_go/config_flow.py index 9d2c1c05518..94594bbd11f 100644 --- a/homeassistant/components/nice_go/config_flow.py +++ b/homeassistant/components/nice_go/config_flow.py @@ -2,17 +2,19 @@ from __future__ import annotations +from collections.abc import Mapping from datetime import datetime import logging -from typing import Any +from typing import TYPE_CHECKING, Any from nice_go import AuthFailedError, NiceGOApi import voluptuous as vol from homeassistant.config_entries import ConfigFlow, ConfigFlowResult -from homeassistant.const import CONF_EMAIL, CONF_PASSWORD +from homeassistant.const import CONF_EMAIL, CONF_NAME, CONF_PASSWORD from homeassistant.helpers.aiohttp_client import async_get_clientsession +from . import NiceGOConfigEntry from .const import CONF_REFRESH_TOKEN, CONF_REFRESH_TOKEN_CREATION_TIME, DOMAIN _LOGGER = logging.getLogger(__name__) @@ -29,6 +31,7 @@ class NiceGOConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Nice G.O.""" VERSION = 1 + reauth_entry: NiceGOConfigEntry | None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -66,3 +69,57 @@ class NiceGOConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors ) + + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Handle re-authentication.""" + self.reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Confirm re-authentication.""" + errors = {} + + if TYPE_CHECKING: + assert self.reauth_entry is not None + + if user_input is not None: + hub = NiceGOApi() + + try: + refresh_token = await hub.authenticate( + user_input[CONF_EMAIL], + user_input[CONF_PASSWORD], + async_get_clientsession(self.hass), + ) + except AuthFailedError: + errors["base"] = "invalid_auth" + except Exception: # noqa: BLE001 + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + else: + return self.async_update_reload_and_abort( + self.reauth_entry, + data={ + **user_input, + CONF_REFRESH_TOKEN: refresh_token, + CONF_REFRESH_TOKEN_CREATION_TIME: datetime.now().timestamp(), + }, + unique_id=user_input[CONF_EMAIL], + ) + + return self.async_show_form( + step_id="reauth_confirm", + data_schema=self.add_suggested_values_to_schema( + STEP_USER_DATA_SCHEMA, + user_input or {CONF_EMAIL: self.reauth_entry.data[CONF_EMAIL]}, + ), + description_placeholders={CONF_NAME: self.reauth_entry.title}, + errors=errors, + ) diff --git a/homeassistant/components/nice_go/strings.json b/homeassistant/components/nice_go/strings.json index 30a2bbf58b6..f83207ad977 100644 --- a/homeassistant/components/nice_go/strings.json +++ b/homeassistant/components/nice_go/strings.json @@ -1,6 +1,13 @@ { "config": { "step": { + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "data": { + "email": "[%key:common::config_flow::data::email%]", + "password": "[%key:common::config_flow::data::password%]" + } + }, "user": { "data": { "email": "[%key:common::config_flow::data::email%]", diff --git a/tests/components/nice_go/test_config_flow.py b/tests/components/nice_go/test_config_flow.py index 67930b9f752..9c25a640c75 100644 --- a/tests/components/nice_go/test_config_flow.py +++ b/tests/components/nice_go/test_config_flow.py @@ -16,6 +16,8 @@ from homeassistant.const import CONF_EMAIL, CONF_PASSWORD from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from . import setup_integration + from tests.common import MockConfigEntry @@ -109,3 +111,71 @@ async def test_duplicate_device( ) assert result["type"] is FlowResultType.ABORT assert result["reason"] == "already_configured" + + +async def test_reauth( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_nice_go: AsyncMock, +) -> None: + """Test reauth flow.""" + + await setup_integration(hass, mock_config_entry, []) + + result = await mock_config_entry.start_reauth_flow(hass) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_EMAIL: "test-email", + CONF_PASSWORD: "other-fake-password", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert len(hass.config_entries.async_entries()) == 1 + + +@pytest.mark.parametrize( + ("side_effect", "expected_error"), + [(AuthFailedError, "invalid_auth"), (Exception, "unknown")], +) +async def test_reauth_exceptions( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_nice_go: AsyncMock, + side_effect: Exception, + expected_error: str, +) -> None: + """Test we handle invalid auth.""" + mock_nice_go.authenticate.side_effect = side_effect + await setup_integration(hass, mock_config_entry, []) + + result = await mock_config_entry.start_reauth_flow(hass) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_EMAIL: "test-email", + CONF_PASSWORD: "test-password", + }, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {"base": expected_error} + mock_nice_go.authenticate.side_effect = None + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_EMAIL: "test-email", + CONF_PASSWORD: "test-password", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert len(hass.config_entries.async_entries()) == 1 diff --git a/tests/components/nice_go/test_init.py b/tests/components/nice_go/test_init.py index 9c9bf28ca7a..23d496df238 100644 --- a/tests/components/nice_go/test_init.py +++ b/tests/components/nice_go/test_init.py @@ -10,7 +10,7 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components.nice_go.const import DOMAIN -from homeassistant.config_entries import ConfigEntryState +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import issue_registry as ir @@ -33,29 +33,32 @@ async def test_unload_entry( assert mock_config_entry.state is ConfigEntryState.NOT_LOADED -@pytest.mark.parametrize( - ("side_effect", "entry_state"), - [ - ( - AuthFailedError(), - ConfigEntryState.SETUP_ERROR, - ), - (ApiError(), ConfigEntryState.SETUP_RETRY), - ], -) -async def test_setup_failure( +async def test_setup_failure_api_error( hass: HomeAssistant, mock_nice_go: AsyncMock, mock_config_entry: MockConfigEntry, - side_effect: Exception, - entry_state: ConfigEntryState, ) -> None: """Test reauth trigger setup.""" - mock_nice_go.authenticate_refresh.side_effect = side_effect + mock_nice_go.authenticate_refresh.side_effect = ApiError() await setup_integration(hass, mock_config_entry, []) - assert mock_config_entry.state is entry_state + assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_setup_failure_auth_failed( + hass: HomeAssistant, + mock_nice_go: AsyncMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test reauth trigger setup.""" + + mock_nice_go.authenticate_refresh.side_effect = AuthFailedError() + + await setup_integration(hass, mock_config_entry, []) + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + + assert any(mock_config_entry.async_get_active_flows(hass, {SOURCE_REAUTH})) async def test_firmware_update_required( @@ -176,6 +179,8 @@ async def test_update_refresh_token_auth_failed( assert mock_nice_go.get_all_barriers.call_count == 1 assert mock_config_entry.data["refresh_token"] == "test-refresh-token" assert "Authentication failed" in caplog.text + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + assert any(mock_config_entry.async_get_active_flows(hass, {SOURCE_REAUTH})) async def test_client_listen_api_error(