Add reauth flow to Google Tasks (#109517)

* Add reauth flow to Google Tasks

* Update homeassistant/components/google_tasks/config_flow.py

Co-authored-by: Jan-Philipp Benecke <github@bnck.me>

* Add tests

* Reauth

* Remove insta reauth

* Fix

---------

Co-authored-by: Jan-Philipp Benecke <github@bnck.me>
This commit is contained in:
Joost Lekkerkerker 2024-04-19 17:38:39 +02:00 committed by GitHub
parent ff83d9acff
commit c108c7df38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 204 additions and 24 deletions

View file

@ -2,12 +2,12 @@
from __future__ import annotations from __future__ import annotations
from aiohttp import ClientError from aiohttp import ClientError, ClientResponseError
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from . import api from . import api
@ -18,8 +18,6 @@ PLATFORMS: list[Platform] = [Platform.TODO]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Google Tasks from a config entry.""" """Set up Google Tasks from a config entry."""
hass.data.setdefault(DOMAIN, {})
implementation = ( implementation = (
await config_entry_oauth2_flow.async_get_config_entry_implementation( await config_entry_oauth2_flow.async_get_config_entry_implementation(
hass, entry hass, entry
@ -29,10 +27,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
auth = api.AsyncConfigEntryAuth(hass, session) auth = api.AsyncConfigEntryAuth(hass, session)
try: try:
await auth.async_get_access_token() await auth.async_get_access_token()
except ClientResponseError as err:
if 400 <= err.status < 500:
raise ConfigEntryAuthFailed(
"OAuth session is not valid, reauth required"
) from err
raise ConfigEntryNotReady from err
except ClientError as err: except ClientError as err:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
hass.data[DOMAIN][entry.entry_id] = auth hass.data.setdefault(DOMAIN, {})[entry.entry_id] = auth
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

View file

@ -1,5 +1,6 @@
"""Config flow for Google Tasks.""" """Config flow for Google Tasks."""
from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
@ -8,7 +9,7 @@ from googleapiclient.discovery import build
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from googleapiclient.http import HttpRequest from googleapiclient.http import HttpRequest
from homeassistant.config_entries import ConfigFlowResult from homeassistant.config_entries import ConfigEntry, ConfigFlowResult
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
@ -22,6 +23,8 @@ class OAuth2FlowHandler(
DOMAIN = DOMAIN DOMAIN = DOMAIN
reauth_entry: ConfigEntry | None = None
@property @property
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
"""Return logger.""" """Return logger."""
@ -39,11 +42,21 @@ class OAuth2FlowHandler(
async def async_oauth_create_entry(self, data: dict[str, Any]) -> ConfigFlowResult: async def async_oauth_create_entry(self, data: dict[str, Any]) -> ConfigFlowResult:
"""Create an entry for the flow.""" """Create an entry for the flow."""
credentials = Credentials(token=data[CONF_TOKEN][CONF_ACCESS_TOKEN])
try: try:
user_resource = build(
"oauth2",
"v2",
credentials=credentials,
)
user_resource_cmd: HttpRequest = user_resource.userinfo().get()
user_resource_info = await self.hass.async_add_executor_job(
user_resource_cmd.execute
)
resource = build( resource = build(
"tasks", "tasks",
"v1", "v1",
credentials=Credentials(token=data[CONF_TOKEN][CONF_ACCESS_TOKEN]), credentials=credentials,
) )
cmd: HttpRequest = resource.tasklists().list() cmd: HttpRequest = resource.tasklists().list()
await self.hass.async_add_executor_job(cmd.execute) await self.hass.async_add_executor_job(cmd.execute)
@ -56,4 +69,32 @@ class OAuth2FlowHandler(
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
self.logger.exception("Unknown error occurred") self.logger.exception("Unknown error occurred")
return self.async_abort(reason="unknown") return self.async_abort(reason="unknown")
return self.async_create_entry(title=self.flow_impl.name, data=data) user_id = user_resource_info["id"]
if not self.reauth_entry:
await self.async_set_unique_id(user_id)
self._abort_if_unique_id_configured()
return self.async_create_entry(title=user_resource_info["name"], data=data)
if self.reauth_entry.unique_id == user_id or not self.reauth_entry.unique_id:
return self.async_update_reload_and_abort(
self.reauth_entry, unique_id=user_id, data=data
)
return self.async_abort(reason="wrong_account")
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Perform reauth upon an API authentication error."""
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Confirm reauth dialog."""
if user_input is None:
return self.async_show_form(step_id="reauth_confirm")
return await self.async_step_user()

View file

@ -6,7 +6,10 @@ DOMAIN = "google_tasks"
OAUTH2_AUTHORIZE = "https://accounts.google.com/o/oauth2/v2/auth" OAUTH2_AUTHORIZE = "https://accounts.google.com/o/oauth2/v2/auth"
OAUTH2_TOKEN = "https://oauth2.googleapis.com/token" OAUTH2_TOKEN = "https://oauth2.googleapis.com/token"
OAUTH2_SCOPES = ["https://www.googleapis.com/auth/tasks"] OAUTH2_SCOPES = [
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/userinfo.profile",
]
class TaskStatus(StrEnum): class TaskStatus(StrEnum):

View file

@ -18,6 +18,7 @@
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]", "user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]",
"access_not_configured": "Unable to access the Google API:\n\n{message}", "access_not_configured": "Unable to access the Google API:\n\n{message}",
"unknown": "[%key:common::config_flow::error::unknown%]", "unknown": "[%key:common::config_flow::error::unknown%]",
"wrong_account": "Wrong account: Please authenticate with the right account.",
"oauth_timeout": "[%key:common::config_flow::abort::oauth2_timeout%]", "oauth_timeout": "[%key:common::config_flow::abort::oauth2_timeout%]",
"oauth_unauthorized": "[%key:common::config_flow::abort::oauth2_unauthorized%]", "oauth_unauthorized": "[%key:common::config_flow::abort::oauth2_unauthorized%]",
"oauth_failed": "[%key:common::config_flow::abort::oauth2_failed%]" "oauth_failed": "[%key:common::config_flow::abort::oauth2_failed%]"

View file

@ -54,6 +54,7 @@ def mock_config_entry(token_entry: dict[str, Any]) -> MockConfigEntry:
"""Fixture for a config entry.""" """Fixture for a config entry."""
return MockConfigEntry( return MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
unique_id="123",
data={ data={
"auth_implementation": DOMAIN, "auth_implementation": DOMAIN,
"token": token_entry, "token": token_entry,

View file

@ -1,9 +1,11 @@
"""Test the Google Tasks config flow.""" """Test the Google Tasks config flow."""
from unittest.mock import patch from collections.abc import Generator
from unittest.mock import Mock, patch
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from httplib2 import Response from httplib2 import Response
import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.google_tasks.const import ( from homeassistant.components.google_tasks.const import (
@ -15,18 +17,37 @@ from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from tests.common import load_fixture from tests.common import MockConfigEntry, load_fixture
from tests.test_util.aiohttp import AiohttpClientMocker
CLIENT_ID = "1234" CLIENT_ID = "1234"
CLIENT_SECRET = "5678" CLIENT_SECRET = "5678"
@pytest.fixture
def user_identifier() -> str:
"""Return a unique user ID."""
return "123"
@pytest.fixture
def setup_userinfo(user_identifier: str) -> Generator[Mock, None, None]:
"""Set up userinfo."""
with patch("homeassistant.components.google_tasks.config_flow.build") as mock:
mock.return_value.userinfo.return_value.get.return_value.execute.return_value = {
"id": user_identifier,
"name": "Test Name",
}
yield mock
async def test_full_flow( async def test_full_flow(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth, hass_client_no_auth,
aioclient_mock, aioclient_mock: AiohttpClientMocker,
current_request_with_host, current_request_with_host,
setup_credentials, setup_credentials,
setup_userinfo,
) -> None: ) -> None:
"""Check full flow.""" """Check full flow."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -44,7 +65,8 @@ async def test_full_flow(
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}" f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback" "&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}" f"&state={state}"
"&scope=https://www.googleapis.com/auth/tasks" "&scope=https://www.googleapis.com/auth/tasks+"
"https://www.googleapis.com/auth/userinfo.profile"
"&access_type=offline&prompt=consent" "&access_type=offline&prompt=consent"
) )
@ -63,14 +85,13 @@ async def test_full_flow(
}, },
) )
with ( with patch(
patch(
"homeassistant.components.google_tasks.async_setup_entry", return_value=True "homeassistant.components.google_tasks.async_setup_entry", return_value=True
) as mock_setup, ) as mock_setup:
patch("homeassistant.components.google_tasks.config_flow.build"),
):
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].unique_id == "123"
assert result["result"].title == "Test Name"
assert len(hass.config_entries.async_entries(DOMAIN)) == 1 assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
@ -78,9 +99,10 @@ async def test_full_flow(
async def test_api_not_enabled( async def test_api_not_enabled(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth, hass_client_no_auth,
aioclient_mock, aioclient_mock: AiohttpClientMocker,
current_request_with_host, current_request_with_host,
setup_credentials, setup_credentials,
setup_userinfo,
) -> None: ) -> None:
"""Check flow aborts if api is not enabled.""" """Check flow aborts if api is not enabled."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -98,7 +120,8 @@ async def test_api_not_enabled(
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}" f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback" "&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}" f"&state={state}"
"&scope=https://www.googleapis.com/auth/tasks" "&scope=https://www.googleapis.com/auth/tasks+"
"https://www.googleapis.com/auth/userinfo.profile"
"&access_type=offline&prompt=consent" "&access_type=offline&prompt=consent"
) )
@ -137,9 +160,10 @@ async def test_api_not_enabled(
async def test_general_exception( async def test_general_exception(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth, hass_client_no_auth,
aioclient_mock, aioclient_mock: AiohttpClientMocker,
current_request_with_host, current_request_with_host,
setup_credentials, setup_credentials,
setup_userinfo,
) -> None: ) -> None:
"""Check flow aborts if exception happens.""" """Check flow aborts if exception happens."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -157,7 +181,8 @@ async def test_general_exception(
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}" f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback" "&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}" f"&state={state}"
"&scope=https://www.googleapis.com/auth/tasks" "&scope=https://www.googleapis.com/auth/tasks+"
"https://www.googleapis.com/auth/userinfo.profile"
"&access_type=offline&prompt=consent" "&access_type=offline&prompt=consent"
) )
@ -184,3 +209,108 @@ async def test_general_exception(
assert result["type"] is FlowResultType.ABORT assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "unknown" assert result["reason"] == "unknown"
@pytest.mark.parametrize(
("user_identifier", "abort_reason", "resulting_access_token", "starting_unique_id"),
[
(
"123",
"reauth_successful",
"updated-access-token",
"123",
),
(
"123",
"reauth_successful",
"updated-access-token",
None,
),
(
"345",
"wrong_account",
"mock-access",
"123",
),
],
)
async def test_reauth(
hass: HomeAssistant,
hass_client_no_auth,
aioclient_mock: AiohttpClientMocker,
current_request_with_host,
setup_credentials,
setup_userinfo,
user_identifier: str,
abort_reason: str,
resulting_access_token: str,
starting_unique_id: str | None,
) -> None:
"""Test the re-authentication case updates the correct config entry."""
config_entry = MockConfigEntry(
domain=DOMAIN,
unique_id=starting_unique_id,
data={
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access",
}
},
)
config_entry.add_to_hass(hass)
config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["url"] == (
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}"
"&scope=https://www.googleapis.com/auth/tasks+"
"https://www.googleapis.com/auth/userinfo.profile"
"&access_type=offline&prompt=consent"
)
client = await hass_client_no_auth()
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
aioclient_mock.clear_requests()
aioclient_mock.post(
OAUTH2_TOKEN,
json={
"refresh_token": "mock-refresh-token",
"access_token": "updated-access-token",
"type": "Bearer",
"expires_in": 60,
},
)
with patch(
"homeassistant.components.google_tasks.async_setup_entry", return_value=True
):
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert result["type"] == "abort"
assert result["reason"] == abort_reason
assert config_entry.unique_id == "123"
assert "token" in config_entry.data
# Verify access token is refreshed
assert config_entry.data["token"]["access_token"] == resulting_access_token
assert config_entry.data["token"]["refresh_token"] == "mock-refresh-token"

View file

@ -68,7 +68,7 @@ async def test_expired_token_refresh_success(
( (
time.time() - 3600, time.time() - 3600,
http.HTTPStatus.UNAUTHORIZED, http.HTTPStatus.UNAUTHORIZED,
ConfigEntryState.SETUP_RETRY, # Will trigger reauth in the future ConfigEntryState.SETUP_ERROR,
), ),
( (
time.time() - 3600, time.time() - 3600,