Clean up Notion config flow (and tests) (#84007)

* Clean up Notion config flow (and tests)

* Code review
This commit is contained in:
Aaron Bach 2022-12-19 15:03:58 -07:00 committed by GitHub
parent 0d8cd2d067
commit ace20782f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 98 deletions

View file

@ -2,14 +2,16 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING, Any from typing import Any
from aionotion import async_get_client from aionotion import async_get_client
from aionotion.errors import InvalidCredentialsError, NotionError from aionotion.errors import InvalidCredentialsError, NotionError
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import aiohttp_client from homeassistant.helpers import aiohttp_client
@ -21,78 +23,84 @@ AUTH_SCHEMA = vol.Schema(
vol.Required(CONF_PASSWORD): str, vol.Required(CONF_PASSWORD): str,
} }
) )
RE_AUTH_SCHEMA = vol.Schema( REAUTH_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_PASSWORD): str, vol.Required(CONF_PASSWORD): str,
} }
) )
async def async_validate_credentials(
hass: HomeAssistant, username: str, password: str
) -> dict[str, Any]:
"""Validate a Notion username and password (returning any errors)."""
session = aiohttp_client.async_get_clientsession(hass)
errors = {}
try:
await async_get_client(username, password, session=session)
except InvalidCredentialsError:
errors["base"] = "invalid_auth"
except NotionError as err:
LOGGER.error("Unknown Notion error while validation credentials: %s", err)
errors["base"] = "unknown"
except Exception as err: # pylint: disable=broad-except
LOGGER.exception("Unknown error while validation credentials: %s", err)
errors["base"] = "unknown"
return errors
class NotionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class NotionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a Notion config flow.""" """Handle a Notion config flow."""
VERSION = 1 VERSION = 1
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the config flow.""" """Initialize."""
self._password: str | None = None self._reauth_entry: ConfigEntry | None = None
self._username: str | None = None
async def _async_verify(self, step_id: str, schema: vol.Schema) -> FlowResult:
"""Attempt to authenticate the provided credentials."""
if TYPE_CHECKING:
assert self._username
assert self._password
errors = {}
session = aiohttp_client.async_get_clientsession(self.hass)
try:
await async_get_client(self._username, self._password, session=session)
except InvalidCredentialsError:
errors["base"] = "invalid_auth"
except NotionError as err:
LOGGER.error("Unknown Notion error: %s", err)
errors["base"] = "unknown"
if errors:
return self.async_show_form(
step_id=step_id,
data_schema=schema,
errors=errors,
description_placeholders={CONF_USERNAME: self._username},
)
data = {CONF_USERNAME: self._username, CONF_PASSWORD: self._password}
if existing_entry := await self.async_set_unique_id(self._username):
self.hass.config_entries.async_update_entry(existing_entry, data=data)
self.hass.async_create_task(
self.hass.config_entries.async_reload(existing_entry.entry_id)
)
return self.async_abort(reason="reauth_successful")
return self.async_create_entry(title=self._username, data=data)
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle configuration by re-auth.""" """Handle configuration by re-auth."""
self._username = entry_data[CONF_USERNAME] self._reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm() return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm( async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Handle re-auth completion.""" """Handle re-auth completion."""
assert self._reauth_entry
if not user_input: if not user_input:
return self.async_show_form( return self.async_show_form(
step_id="reauth_confirm", step_id="reauth_confirm",
data_schema=RE_AUTH_SCHEMA, data_schema=REAUTH_SCHEMA,
description_placeholders={CONF_USERNAME: self._username}, description_placeholders={
CONF_USERNAME: self._reauth_entry.data[CONF_USERNAME]
},
) )
self._password = user_input[CONF_PASSWORD] if errors := await async_validate_credentials(
self.hass, self._reauth_entry.data[CONF_USERNAME], user_input[CONF_PASSWORD]
):
return self.async_show_form(
step_id="reauth_confirm",
data_schema=REAUTH_SCHEMA,
errors=errors,
description_placeholders={
CONF_USERNAME: self._reauth_entry.data[CONF_USERNAME]
},
)
return await self._async_verify("reauth_confirm", RE_AUTH_SCHEMA) self.hass.config_entries.async_update_entry(
self._reauth_entry, data=self._reauth_entry.data | user_input
)
self.hass.async_create_task(
self.hass.config_entries.async_reload(self._reauth_entry.entry_id)
)
return self.async_abort(reason="reauth_successful")
async def async_step_user( async def async_step_user(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
@ -104,7 +112,13 @@ class NotionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id(user_input[CONF_USERNAME]) await self.async_set_unique_id(user_input[CONF_USERNAME])
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
self._username = user_input[CONF_USERNAME] if errors := await async_validate_credentials(
self._password = user_input[CONF_PASSWORD] self.hass, user_input[CONF_USERNAME], user_input[CONF_PASSWORD]
):
return self.async_show_form(
step_id="user",
data_schema=AUTH_SCHEMA,
errors=errors,
)
return await self._async_verify("user", AUTH_SCHEMA) return self.async_create_entry(title=user_input[CONF_USERNAME], data=user_input)

View file

@ -56,12 +56,20 @@ def data_task_fixture():
return json.loads(load_fixture("task_data.json", "notion")) return json.loads(load_fixture("task_data.json", "notion"))
@pytest.fixture(name="get_client")
def get_client_fixture(client):
"""Define a fixture to mock the async_get_client method."""
return AsyncMock(return_value=client)
@pytest.fixture(name="setup_notion") @pytest.fixture(name="setup_notion")
async def setup_notion_fixture(hass, client, config): async def setup_notion_fixture(hass, config, get_client):
"""Define a fixture to set up Notion.""" """Define a fixture to set up Notion."""
with patch("homeassistant.components.notion.config_flow.async_get_client"), patch( with patch(
"homeassistant.components.notion.config_flow.async_get_client", get_client
), patch("homeassistant.components.notion.async_get_client", get_client), patch(
"homeassistant.components.notion.PLATFORMS", [] "homeassistant.components.notion.PLATFORMS", []
), patch("homeassistant.components.notion.async_get_client", return_value=client): ):
assert await async_setup_component(hass, DOMAIN, config) assert await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done() await hass.async_block_till_done()
yield yield

View file

@ -1,5 +1,5 @@
"""Define tests for the Notion config flow.""" """Define tests for the Notion config flow."""
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from aionotion.errors import InvalidCredentialsError, NotionError from aionotion.errors import InvalidCredentialsError, NotionError
import pytest import pytest
@ -10,6 +10,46 @@ from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
@pytest.mark.parametrize(
"get_client_with_exception,errors",
[
(AsyncMock(side_effect=Exception), {"base": "unknown"}),
(AsyncMock(side_effect=InvalidCredentialsError), {"base": "invalid_auth"}),
(AsyncMock(side_effect=NotionError), {"base": "unknown"}),
],
)
async def test_create_entry(
hass, client, config, errors, get_client_with_exception, setup_notion
):
"""Test creating an etry (including recovery from errors)."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "user"
# Test errors that can arise when getting a Notion API client:
with patch(
"homeassistant.components.notion.config_flow.async_get_client",
get_client_with_exception,
):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=config
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["errors"] == errors
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=config
)
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "user@host.com"
assert result["data"] == {
CONF_USERNAME: "user@host.com",
CONF_PASSWORD: "password123",
}
async def test_duplicate_error(hass, config, config_entry): async def test_duplicate_error(hass, config, config_entry):
"""Test that errors are shown when duplicates are added.""" """Test that errors are shown when duplicates are added."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -20,62 +60,42 @@ async def test_duplicate_error(hass, config, config_entry):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exc,error", "get_client_with_exception,errors",
[ [
(NotionError, "unknown"), (AsyncMock(side_effect=Exception), {"base": "unknown"}),
(InvalidCredentialsError, "invalid_auth"), (AsyncMock(side_effect=InvalidCredentialsError), {"base": "invalid_auth"}),
(AsyncMock(side_effect=NotionError), {"base": "unknown"}),
], ],
) )
async def test_erros(hass, config, error, exc): async def test_reauth(
"""Test that exceptions show the correct error.""" hass, config, config_entry, errors, get_client_with_exception, setup_notion
with patch( ):
"homeassistant.components.notion.config_flow.async_get_client", side_effect=exc """Test that re-auth works."""
):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=config
)
assert result["errors"] == {"base": error}
async def test_step_reauth(hass, config, config_entry, setup_notion):
"""Test that the reauth step works."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_REAUTH}, data=config DOMAIN,
context={
"source": SOURCE_REAUTH,
"entry_id": config_entry.entry_id,
"unique_id": config_entry.unique_id,
},
data=config,
) )
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"]) # Test errors that can arise when getting a Notion API client:
assert result["type"] == data_entry_flow.FlowResultType.FORM with patch(
assert result["step_id"] == "reauth_confirm" "homeassistant.components.notion.config_flow.async_get_client",
get_client_with_exception,
with patch("homeassistant.components.notion.async_setup_entry", return_value=True): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_PASSWORD: "password"} result["flow_id"], user_input={CONF_PASSWORD: "password"}
) )
await hass.async_block_till_done() assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["errors"] == errors
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_PASSWORD: "password"}
)
assert result["type"] == data_entry_flow.FlowResultType.ABORT assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "reauth_successful" assert result["reason"] == "reauth_successful"
assert len(hass.config_entries.async_entries()) == 1 assert len(hass.config_entries.async_entries()) == 1
async def test_show_form(hass):
"""Test that the form is served with no input."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "user"
async def test_step_user(hass, config, setup_notion):
"""Test that the user step works."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=config
)
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "user@host.com"
assert result["data"] == {
CONF_USERNAME: "user@host.com",
CONF_PASSWORD: "password123",
}