diff --git a/homeassistant/components/notion/config_flow.py b/homeassistant/components/notion/config_flow.py index 917c0f8ebb9..1e4adab2910 100644 --- a/homeassistant/components/notion/config_flow.py +++ b/homeassistant/components/notion/config_flow.py @@ -2,14 +2,16 @@ from __future__ import annotations from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import Any from aionotion import async_get_client from aionotion.errors import InvalidCredentialsError, NotionError import voluptuous as vol from homeassistant import config_entries +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_PASSWORD, CONF_USERNAME +from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import aiohttp_client @@ -21,78 +23,84 @@ AUTH_SCHEMA = vol.Schema( vol.Required(CONF_PASSWORD): str, } ) -RE_AUTH_SCHEMA = vol.Schema( +REAUTH_SCHEMA = vol.Schema( { vol.Required(CONF_PASSWORD): str, } ) +async def async_validate_credentials( + hass: HomeAssistant, username: str, password: str +) -> dict[str, Any]: + """Validate a Notion username and password (returning any errors).""" + session = aiohttp_client.async_get_clientsession(hass) + errors = {} + + try: + await async_get_client(username, password, session=session) + except InvalidCredentialsError: + errors["base"] = "invalid_auth" + except NotionError as err: + LOGGER.error("Unknown Notion error while validation credentials: %s", err) + errors["base"] = "unknown" + except Exception as err: # pylint: disable=broad-except + LOGGER.exception("Unknown error while validation credentials: %s", err) + errors["base"] = "unknown" + + return errors + + class NotionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a Notion config flow.""" VERSION = 1 def __init__(self) -> None: - """Initialize the config flow.""" - self._password: str | None = None - self._username: str | None = None - - async def _async_verify(self, step_id: str, schema: vol.Schema) -> FlowResult: - """Attempt to authenticate the provided credentials.""" - if TYPE_CHECKING: - assert self._username - assert self._password - - errors = {} - session = aiohttp_client.async_get_clientsession(self.hass) - - try: - await async_get_client(self._username, self._password, session=session) - except InvalidCredentialsError: - errors["base"] = "invalid_auth" - except NotionError as err: - LOGGER.error("Unknown Notion error: %s", err) - errors["base"] = "unknown" - - if errors: - return self.async_show_form( - step_id=step_id, - data_schema=schema, - errors=errors, - description_placeholders={CONF_USERNAME: self._username}, - ) - - data = {CONF_USERNAME: self._username, CONF_PASSWORD: self._password} - - if existing_entry := await self.async_set_unique_id(self._username): - self.hass.config_entries.async_update_entry(existing_entry, data=data) - self.hass.async_create_task( - self.hass.config_entries.async_reload(existing_entry.entry_id) - ) - return self.async_abort(reason="reauth_successful") - - return self.async_create_entry(title=self._username, data=data) + """Initialize.""" + self._reauth_entry: ConfigEntry | None = None async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: """Handle configuration by re-auth.""" - self._username = entry_data[CONF_USERNAME] + self._reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle re-auth completion.""" + assert self._reauth_entry + if not user_input: return self.async_show_form( step_id="reauth_confirm", - data_schema=RE_AUTH_SCHEMA, - description_placeholders={CONF_USERNAME: self._username}, + data_schema=REAUTH_SCHEMA, + description_placeholders={ + CONF_USERNAME: self._reauth_entry.data[CONF_USERNAME] + }, ) - self._password = user_input[CONF_PASSWORD] + if errors := await async_validate_credentials( + self.hass, self._reauth_entry.data[CONF_USERNAME], user_input[CONF_PASSWORD] + ): + return self.async_show_form( + step_id="reauth_confirm", + data_schema=REAUTH_SCHEMA, + errors=errors, + description_placeholders={ + CONF_USERNAME: self._reauth_entry.data[CONF_USERNAME] + }, + ) - return await self._async_verify("reauth_confirm", RE_AUTH_SCHEMA) + self.hass.config_entries.async_update_entry( + self._reauth_entry, data=self._reauth_entry.data | user_input + ) + self.hass.async_create_task( + self.hass.config_entries.async_reload(self._reauth_entry.entry_id) + ) + return self.async_abort(reason="reauth_successful") async def async_step_user( self, user_input: dict[str, str] | None = None @@ -104,7 +112,13 @@ class NotionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.async_set_unique_id(user_input[CONF_USERNAME]) self._abort_if_unique_id_configured() - self._username = user_input[CONF_USERNAME] - self._password = user_input[CONF_PASSWORD] + if errors := await async_validate_credentials( + self.hass, user_input[CONF_USERNAME], user_input[CONF_PASSWORD] + ): + return self.async_show_form( + step_id="user", + data_schema=AUTH_SCHEMA, + errors=errors, + ) - return await self._async_verify("user", AUTH_SCHEMA) + return self.async_create_entry(title=user_input[CONF_USERNAME], data=user_input) diff --git a/tests/components/notion/conftest.py b/tests/components/notion/conftest.py index 7d87b9adc64..e29ea83ef2a 100644 --- a/tests/components/notion/conftest.py +++ b/tests/components/notion/conftest.py @@ -56,12 +56,20 @@ def data_task_fixture(): return json.loads(load_fixture("task_data.json", "notion")) +@pytest.fixture(name="get_client") +def get_client_fixture(client): + """Define a fixture to mock the async_get_client method.""" + return AsyncMock(return_value=client) + + @pytest.fixture(name="setup_notion") -async def setup_notion_fixture(hass, client, config): +async def setup_notion_fixture(hass, config, get_client): """Define a fixture to set up Notion.""" - with patch("homeassistant.components.notion.config_flow.async_get_client"), patch( + with patch( + "homeassistant.components.notion.config_flow.async_get_client", get_client + ), patch("homeassistant.components.notion.async_get_client", get_client), patch( "homeassistant.components.notion.PLATFORMS", [] - ), patch("homeassistant.components.notion.async_get_client", return_value=client): + ): assert await async_setup_component(hass, DOMAIN, config) await hass.async_block_till_done() yield diff --git a/tests/components/notion/test_config_flow.py b/tests/components/notion/test_config_flow.py index 92e285ba899..0eff3890274 100644 --- a/tests/components/notion/test_config_flow.py +++ b/tests/components/notion/test_config_flow.py @@ -1,5 +1,5 @@ """Define tests for the Notion config flow.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from aionotion.errors import InvalidCredentialsError, NotionError import pytest @@ -10,6 +10,46 @@ from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER from homeassistant.const import CONF_PASSWORD, CONF_USERNAME +@pytest.mark.parametrize( + "get_client_with_exception,errors", + [ + (AsyncMock(side_effect=Exception), {"base": "unknown"}), + (AsyncMock(side_effect=InvalidCredentialsError), {"base": "invalid_auth"}), + (AsyncMock(side_effect=NotionError), {"base": "unknown"}), + ], +) +async def test_create_entry( + hass, client, config, errors, get_client_with_exception, setup_notion +): + """Test creating an etry (including recovery from errors).""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "user" + + # Test errors that can arise when getting a Notion API client: + with patch( + "homeassistant.components.notion.config_flow.async_get_client", + get_client_with_exception, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data=config + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["errors"] == errors + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=config + ) + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert result["title"] == "user@host.com" + assert result["data"] == { + CONF_USERNAME: "user@host.com", + CONF_PASSWORD: "password123", + } + + async def test_duplicate_error(hass, config, config_entry): """Test that errors are shown when duplicates are added.""" result = await hass.config_entries.flow.async_init( @@ -20,62 +60,42 @@ async def test_duplicate_error(hass, config, config_entry): @pytest.mark.parametrize( - "exc,error", + "get_client_with_exception,errors", [ - (NotionError, "unknown"), - (InvalidCredentialsError, "invalid_auth"), + (AsyncMock(side_effect=Exception), {"base": "unknown"}), + (AsyncMock(side_effect=InvalidCredentialsError), {"base": "invalid_auth"}), + (AsyncMock(side_effect=NotionError), {"base": "unknown"}), ], ) -async def test_erros(hass, config, error, exc): - """Test that exceptions show the correct error.""" - with patch( - "homeassistant.components.notion.config_flow.async_get_client", side_effect=exc - ): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=config - ) - assert result["errors"] == {"base": error} - - -async def test_step_reauth(hass, config, config_entry, setup_notion): - """Test that the reauth step works.""" +async def test_reauth( + hass, config, config_entry, errors, get_client_with_exception, setup_notion +): + """Test that re-auth works.""" result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_REAUTH}, data=config + DOMAIN, + context={ + "source": SOURCE_REAUTH, + "entry_id": config_entry.entry_id, + "unique_id": config_entry.unique_id, + }, + data=config, ) assert result["step_id"] == "reauth_confirm" - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == data_entry_flow.FlowResultType.FORM - assert result["step_id"] == "reauth_confirm" - - with patch("homeassistant.components.notion.async_setup_entry", return_value=True): + # Test errors that can arise when getting a Notion API client: + with patch( + "homeassistant.components.notion.config_flow.async_get_client", + get_client_with_exception, + ): result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={CONF_PASSWORD: "password"} ) - await hass.async_block_till_done() + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["errors"] == errors + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_PASSWORD: "password"} + ) assert result["type"] == data_entry_flow.FlowResultType.ABORT assert result["reason"] == "reauth_successful" assert len(hass.config_entries.async_entries()) == 1 - - -async def test_show_form(hass): - """Test that the form is served with no input.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER} - ) - assert result["type"] == data_entry_flow.FlowResultType.FORM - assert result["step_id"] == "user" - - -async def test_step_user(hass, config, setup_notion): - """Test that the user step works.""" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=config - ) - assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY - assert result["title"] == "user@host.com" - assert result["data"] == { - CONF_USERNAME: "user@host.com", - CONF_PASSWORD: "password123", - }