Handle expiration of nest auth credentials (#44202)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
5bdf022bf2
commit
81341bbf91
4 changed files with 262 additions and 55 deletions
|
@ -6,14 +6,14 @@ import logging
|
|||
import threading
|
||||
|
||||
from google_nest_sdm.event import AsyncEventCallback, EventMessage
|
||||
from google_nest_sdm.exceptions import GoogleNestException
|
||||
from google_nest_sdm.exceptions import AuthException, GoogleNestException
|
||||
from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber
|
||||
from nest import Nest
|
||||
from nest.nest import APIError, AuthorizationError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
|
||||
from homeassistant.const import (
|
||||
CONF_BINARY_SENSORS,
|
||||
CONF_CLIENT_ID,
|
||||
|
@ -231,6 +231,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
|||
|
||||
try:
|
||||
await subscriber.start_async()
|
||||
except AuthException as err:
|
||||
_LOGGER.debug("Subscriber authentication error: %s", err)
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": SOURCE_REAUTH},
|
||||
data=entry.data,
|
||||
)
|
||||
)
|
||||
return False
|
||||
except GoogleNestException as err:
|
||||
_LOGGER.error("Subscriber error: %s", err)
|
||||
subscriber.stop_async()
|
||||
|
|
|
@ -75,6 +75,12 @@ class NestFlowHandler(
|
|||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize NestFlowHandler."""
|
||||
super().__init__()
|
||||
# When invoked for reauth, allows updating an existing config entry
|
||||
self._reauth = False
|
||||
|
||||
@classmethod
|
||||
def register_sdm_api(cls, hass):
|
||||
"""Configure the flow handler to use the SDM API."""
|
||||
|
@ -103,19 +109,56 @@ class NestFlowHandler(
|
|||
|
||||
async def async_oauth_create_entry(self, data: dict) -> dict:
|
||||
"""Create an entry for the SDM flow."""
|
||||
assert self.is_sdm_api(), "Step only supported for SDM API"
|
||||
data[DATA_SDM] = {}
|
||||
await self.async_set_unique_id(DOMAIN)
|
||||
# Update existing config entry when in the reauth flow. This
|
||||
# integration only supports one config entry so remove any prior entries
|
||||
# added before the "single_instance_allowed" check was added
|
||||
existing_entries = self.hass.config_entries.async_entries(DOMAIN)
|
||||
if existing_entries:
|
||||
updated = False
|
||||
for entry in existing_entries:
|
||||
if updated:
|
||||
await self.hass.config_entries.async_remove(entry.entry_id)
|
||||
continue
|
||||
updated = True
|
||||
self.hass.config_entries.async_update_entry(
|
||||
entry, data=data, unique_id=DOMAIN
|
||||
)
|
||||
await self.hass.config_entries.async_reload(entry.entry_id)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
return await super().async_oauth_create_entry(data)
|
||||
|
||||
async def async_step_reauth(self, user_input=None):
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
assert self.is_sdm_api(), "Step only supported for SDM API"
|
||||
self._reauth = True # Forces update of existing config entry
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(self, user_input=None):
|
||||
"""Confirm reauth dialog."""
|
||||
assert self.is_sdm_api(), "Step only supported for SDM API"
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
data_schema=vol.Schema({}),
|
||||
)
|
||||
return await self.async_step_user()
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle a flow initialized by the user."""
|
||||
if self.is_sdm_api():
|
||||
# Reauth will update an existing entry
|
||||
if self.hass.config_entries.async_entries(DOMAIN) and not self._reauth:
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
return await super().async_step_user(user_input)
|
||||
return await self.async_step_init(user_input)
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Handle a flow start."""
|
||||
if self.is_sdm_api():
|
||||
raise UnexpectedStateError("Step only supported for legacy API")
|
||||
assert not self.is_sdm_api(), "Step only supported for legacy API"
|
||||
|
||||
flows = self.hass.data.get(DATA_FLOW_IMPL, {})
|
||||
|
||||
|
@ -145,8 +188,7 @@ class NestFlowHandler(
|
|||
implementation type we expect a pin or an external component to
|
||||
deliver the authentication code.
|
||||
"""
|
||||
if self.is_sdm_api():
|
||||
raise UnexpectedStateError("Step only supported for legacy API")
|
||||
assert not self.is_sdm_api(), "Step only supported for legacy API"
|
||||
|
||||
flow = self.hass.data[DATA_FLOW_IMPL][self.flow_impl]
|
||||
|
||||
|
@ -188,8 +230,7 @@ class NestFlowHandler(
|
|||
|
||||
async def async_step_import(self, info):
|
||||
"""Import existing auth from Nest."""
|
||||
if self.is_sdm_api():
|
||||
raise UnexpectedStateError("Step only supported for legacy API")
|
||||
assert not self.is_sdm_api(), "Step only supported for legacy API"
|
||||
|
||||
if self.hass.config_entries.async_entries(DOMAIN):
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
"pick_implementation": {
|
||||
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "The Nest integration needs to re-authenticate your account"
|
||||
},
|
||||
"init": {
|
||||
"title": "Authentication Provider",
|
||||
"description": "[%key:common::config_flow::title::oauth2_pick_implementation%]",
|
||||
|
@ -30,7 +34,8 @@
|
|||
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
|
||||
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
|
||||
"unknown_authorize_url_generation": "[%key:common::config_flow::abort::unknown_authorize_url_generation%]",
|
||||
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]"
|
||||
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
|
||||
},
|
||||
"create_entry": {
|
||||
"default": "[%key:common::config_flow::create_entry::authenticated%]"
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
"""Test the Google Nest Device Access config flow."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, setup
|
||||
from homeassistant.components.nest.const import DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN
|
||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from .common import MockConfigEntry
|
||||
|
||||
from tests.async_mock import patch
|
||||
|
||||
CLIENT_ID = "1234"
|
||||
|
@ -11,64 +16,210 @@ CLIENT_SECRET = "5678"
|
|||
PROJECT_ID = "project-id-4321"
|
||||
SUBSCRIBER_ID = "subscriber-id-9876"
|
||||
|
||||
CONFIG = {
|
||||
DOMAIN: {
|
||||
"project_id": PROJECT_ID,
|
||||
"subscriber_id": SUBSCRIBER_ID,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
"http": {"base_url": "https://example.com"},
|
||||
}
|
||||
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
{
|
||||
DOMAIN: {
|
||||
"project_id": PROJECT_ID,
|
||||
"subscriber_id": SUBSCRIBER_ID,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
|
||||
def get_config_entry(hass):
|
||||
"""Return a single config entry."""
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
return entries[0]
|
||||
|
||||
|
||||
class OAuthFixture:
|
||||
"""Simulate the oauth flow used by the config flow."""
|
||||
|
||||
def __init__(self, hass, aiohttp_client, aioclient_mock):
|
||||
"""Initialize OAuthFixture."""
|
||||
self.hass = hass
|
||||
self.aiohttp_client = aiohttp_client
|
||||
self.aioclient_mock = aioclient_mock
|
||||
|
||||
async def async_oauth_flow(self, result):
|
||||
"""Invoke the oauth flow with fake responses."""
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
self.hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
"http": {"base_url": "https://example.com"},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID)
|
||||
assert result["type"] == "external"
|
||||
assert result["url"] == (
|
||||
f"{oauth_authorize}?response_type=code&client_id={CLIENT_ID}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&state={state}&scope=https://www.googleapis.com/auth/sdm.service"
|
||||
"+https://www.googleapis.com/auth/pubsub"
|
||||
"&access_type=offline&prompt=consent"
|
||||
)
|
||||
|
||||
client = await self.aiohttp_client(self.hass.http.app)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
self.aioclient_mock.post(
|
||||
OAUTH2_TOKEN,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.nest.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
await self.hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def oauth(hass, aiohttp_client, aioclient_mock, current_request_with_host):
|
||||
"""Create the simulated oauth flow."""
|
||||
return OAuthFixture(hass, aiohttp_client, aioclient_mock)
|
||||
|
||||
|
||||
async def test_full_flow(hass, oauth):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(hass, DOMAIN, CONFIG)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
await oauth.async_oauth_flow(result)
|
||||
|
||||
entry = get_config_entry(hass)
|
||||
assert entry.title == "Configuration.yaml"
|
||||
assert "token" in entry.data
|
||||
entry.data["token"].pop("expires_at")
|
||||
assert entry.unique_id == DOMAIN
|
||||
assert entry.data["token"] == {
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
|
||||
|
||||
async def test_reauth(hass, oauth):
|
||||
"""Test Nest reauthentication."""
|
||||
|
||||
assert await setup.async_setup_component(hass, DOMAIN, CONFIG)
|
||||
|
||||
old_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
"auth_implementation": DOMAIN,
|
||||
"token": {
|
||||
# Verify this is replaced at end of the test
|
||||
"access_token": "some-revoked-token",
|
||||
},
|
||||
"sdm": {},
|
||||
},
|
||||
unique_id=DOMAIN,
|
||||
)
|
||||
old_entry.add_to_hass(hass)
|
||||
|
||||
entry = get_config_entry(hass)
|
||||
assert entry.data["token"] == {
|
||||
"access_token": "some-revoked-token",
|
||||
}
|
||||
|
||||
await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data
|
||||
)
|
||||
|
||||
oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID)
|
||||
assert result["url"] == (
|
||||
f"{oauth_authorize}?response_type=code&client_id={CLIENT_ID}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&state={state}&scope=https://www.googleapis.com/auth/sdm.service"
|
||||
"+https://www.googleapis.com/auth/pubsub"
|
||||
"&access_type=offline&prompt=consent"
|
||||
# Advance through the reauth flow
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
assert flows[0]["step_id"] == "reauth_confirm"
|
||||
|
||||
# Run the oauth flow
|
||||
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
|
||||
await oauth.async_oauth_flow(result)
|
||||
|
||||
# Verify existing tokens are replaced
|
||||
entry = get_config_entry(hass)
|
||||
entry.data["token"].pop("expires_at")
|
||||
assert entry.unique_id == DOMAIN
|
||||
assert entry.data["token"] == {
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
|
||||
|
||||
async def test_single_config_entry(hass):
|
||||
"""Test that only a single config entry is allowed."""
|
||||
old_entry = MockConfigEntry(
|
||||
domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}}
|
||||
)
|
||||
old_entry.add_to_hass(hass)
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert await setup.async_setup_component(hass, DOMAIN, CONFIG)
|
||||
|
||||
aioclient_mock.post(
|
||||
OAUTH2_TOKEN,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == "abort"
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.nest.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
async def test_unexpected_existing_config_entries(hass, oauth):
|
||||
"""Test Nest reauthentication with multiple existing config entries."""
|
||||
# Note that this case will not happen in the future since only a single
|
||||
# instance is now allowed, but this may have been allowed in the past.
|
||||
# On reauth, only one entry is kept and the others are deleted.
|
||||
|
||||
assert await setup.async_setup_component(hass, DOMAIN, CONFIG)
|
||||
|
||||
old_entry = MockConfigEntry(
|
||||
domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}}
|
||||
)
|
||||
old_entry.add_to_hass(hass)
|
||||
|
||||
old_entry = MockConfigEntry(
|
||||
domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}}
|
||||
)
|
||||
old_entry.add_to_hass(hass)
|
||||
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 2
|
||||
|
||||
# Invoke the reauth flow
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
|
||||
await oauth.async_oauth_flow(result)
|
||||
|
||||
# Only a single entry now exists, and the other was cleaned up
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.unique_id == DOMAIN
|
||||
entry.data["token"].pop("expires_at")
|
||||
assert entry.data["token"] == {
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue