Add cloud account linking support (#28210)

* Add cloud account linking support

* Update account_link.py
This commit is contained in:
Paulus Schoutsen 2019-10-25 16:04:24 -07:00 committed by GitHub
parent 475b43500a
commit 08cc9fd375
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 407 additions and 21 deletions

View file

@ -33,6 +33,8 @@ STAGE_1_INTEGRATIONS = {
"recorder",
# To make sure we forward data to other instances
"mqtt_eventstream",
# To provide account link implementations
"cloud",
}

View file

@ -20,7 +20,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
from homeassistant.loader import bind_hass
from homeassistant.util.aiohttp import MockRequest
from . import http_api
from . import account_link, http_api
from .client import CloudClient
from .const import (
CONF_ACME_DIRECTORY_SERVER,
@ -38,6 +38,7 @@ from .const import (
CONF_REMOTE_API_URL,
CONF_SUBSCRIPTION_INFO_URL,
CONF_USER_POOL_ID,
CONF_ACCOUNT_LINK_URL,
DOMAIN,
MODE_DEV,
MODE_PROD,
@ -101,6 +102,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA,
vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(),
vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(),
vol.Optional(CONF_ACCOUNT_LINK_URL): vol.Url(),
}
)
},
@ -168,7 +170,6 @@ def is_cloudhook_request(request):
async def async_setup(hass, config):
"""Initialize the Home Assistant cloud."""
# Process configs
if DOMAIN in config:
kwargs = dict(config[DOMAIN])
@ -248,4 +249,7 @@ async def async_setup(hass, config):
cloud.iot.register_on_connect(_on_connect)
await http_api.async_setup(hass)
account_link.async_setup(hass)
return True

View file

@ -0,0 +1,132 @@
"""Account linking via the cloud."""
import asyncio
import logging
from typing import Any
from hass_nabucasa import account_link
from homeassistant.const import MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import event, config_entry_oauth2_flow
from .const import DOMAIN
DATA_SERVICES = "cloud_account_link_services"
CACHE_TIMEOUT = 3600
PATCH_VERSION = int(PATCH_VERSION.split(".")[0])
_LOGGER = logging.getLogger(__name__)
@callback
def async_setup(hass: HomeAssistant):
"""Set up cloud account link."""
config_entry_oauth2_flow.async_add_implementation_provider(
hass, DOMAIN, async_provide_implementation
)
async def async_provide_implementation(hass: HomeAssistant, domain: str):
"""Provide an implementation for a domain."""
services = await _get_services(hass)
for service in services:
if service["service"] == domain and _is_older(service["min_version"]):
return CloudOAuth2Implementation(hass, domain)
return
@callback
def _is_older(version: str) -> bool:
"""Test if a version is older than the current HA version."""
version_parts = version.split(".")
if len(version_parts) != 3:
return False
try:
version_parts = [int(val) for val in version_parts]
except ValueError:
return False
cur_version_parts = [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]
return version_parts <= cur_version_parts
async def _get_services(hass):
"""Get the available services."""
services = hass.data.get(DATA_SERVICES)
if services is not None:
return services
services = await account_link.async_fetch_available_services(hass.data[DOMAIN])
hass.data[DATA_SERVICES] = services
@callback
def clear_services(_now):
"""Clear services cache."""
hass.data.pop(DATA_SERVICES, None)
event.async_call_later(hass, CACHE_TIMEOUT, clear_services)
return services
class CloudOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
"""Cloud implementation of the OAuth2 flow."""
def __init__(self, hass: HomeAssistant, service: str):
"""Initialize cloud OAuth2 implementation."""
self.hass = hass
self.service = service
@property
def name(self) -> str:
"""Name of the implementation."""
return "Home Assistant Cloud"
@property
def domain(self) -> str:
"""Domain that is providing the implementation."""
return DOMAIN
async def async_generate_authorize_url(self, flow_id: str) -> str:
"""Generate a url for the user to authorize."""
helper = account_link.AuthorizeAccountHelper(
self.hass.data[DOMAIN], self.service
)
authorize_url = await helper.async_get_authorize_url()
async def await_tokens():
"""Wait for tokens and pass them on when received."""
try:
tokens = await helper.async_get_tokens()
except asyncio.TimeoutError:
_LOGGER.info("Timeout fetching tokens for flow %s", flow_id)
except account_link.AccountLinkException as err:
_LOGGER.info(
"Failed to fetch tokens for flow %s: %s", flow_id, err.code
)
else:
await self.hass.config_entries.flow.async_configure(
flow_id=flow_id, user_input=tokens
)
self.hass.async_create_task(await_tokens())
return authorize_url
async def async_resolve_external_data(self, external_data: Any) -> dict:
"""Resolve external data to tokens."""
# We already passed in tokens
return external_data
async def _async_refresh_token(self, token: dict) -> dict:
"""Refresh a token."""
return await account_link.async_fetch_access_token(
self.hass.data[DOMAIN], self.service, token["refresh_token"]
)

View file

@ -37,6 +37,7 @@ CONF_REMOTE_API_URL = "remote_api_url"
CONF_ACME_DIRECTORY_SERVER = "acme_directory_server"
CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url"
CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url"
CONF_ACCOUNT_LINK_URL = "account_link_url"
MODE_DEV = "development"
MODE_PROD = "production"

View file

@ -2,7 +2,7 @@
"domain": "cloud",
"name": "Cloud",
"documentation": "https://www.home-assistant.io/integrations/cloud",
"requirements": ["hass-nabucasa==0.22"],
"requirements": ["hass-nabucasa==0.23"],
"dependencies": ["http", "webhook"],
"codeowners": ["@home-assistant/cloud"]
}

View file

@ -8,6 +8,11 @@
"create_entry": {
"default": "Successfully authenticated with Somfy."
},
"step": {
"pick_implementation": {
"title": "Pick Authentication Method"
}
},
"title": "Somfy"
}
}

View file

@ -1,13 +1,18 @@
{
"config": {
"abort": {
"already_setup": "You can only configure one Somfy account.",
"authorize_url_timeout": "Timeout generating authorize url.",
"missing_configuration": "The Somfy component is not configured. Please follow the documentation."
},
"create_entry": {
"default": "Successfully authenticated with Somfy."
},
"title": "Somfy"
}
}
"config": {
"step": {
"pick_implementation": {
"title": "Pick Authentication Method"
}
},
"abort": {
"already_setup": "You can only configure one Somfy account.",
"authorize_url_timeout": "Timeout generating authorize url.",
"missing_configuration": "The Somfy component is not configured. Please follow the documentation."
},
"create_entry": {
"default": "Successfully authenticated with Somfy."
},
"title": "Somfy"
}
}

View file

@ -8,7 +8,7 @@ This module exists of the following parts:
import asyncio
from abc import ABCMeta, ABC, abstractmethod
import logging
from typing import Optional, Any, Dict, cast
from typing import Optional, Any, Dict, cast, Awaitable, Callable
import time
import async_timeout
@ -28,6 +28,7 @@ from .aiohttp_client import async_get_clientsession
DATA_JWT_SECRET = "oauth2_jwt_secret"
DATA_VIEW_REGISTERED = "oauth2_view_reg"
DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
AUTH_CALLBACK_PATH = "/auth/external/callback"
@ -291,11 +292,23 @@ async def async_get_implementations(
hass: HomeAssistant, domain: str
) -> Dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain."""
return cast(
registered = cast(
Dict[str, AbstractOAuth2Implementation],
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
)
if DATA_PROVIDERS not in hass.data:
return registered
registered = dict(registered)
for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items():
implementation = await get_impl(hass, domain)
if implementation is not None:
registered[provider_domain] = implementation
return registered
async def async_get_config_entry_implementation(
hass: HomeAssistant, config_entry: config_entries.ConfigEntry
@ -310,6 +323,23 @@ async def async_get_config_entry_implementation(
return implementation
@callback
def async_add_implementation_provider(
hass: HomeAssistant,
provider_domain: str,
async_provide_implementation: Callable[
[HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]]
],
) -> None:
"""Add an implementation provider.
If no implementation found, return None.
"""
hass.data.setdefault(DATA_PROVIDERS, {})[
provider_domain
] = async_provide_implementation
class OAuth2AuthorizeCallbackView(HomeAssistantView):
"""OAuth2 Authorization Callback View."""
@ -355,9 +385,14 @@ class OAuth2Session:
self.config_entry = config_entry
self.implementation = implementation
@property
def token(self) -> dict:
"""Return the current token."""
return cast(dict, self.config_entry.data["token"])
async def async_ensure_token_valid(self) -> None:
"""Ensure that the current token is valid."""
token = self.config_entry.data["token"]
token = self.token
if token["expires_at"] > time.time():
return

View file

@ -10,7 +10,7 @@ certifi>=2019.9.11
contextvars==2.4;python_version<"3.7"
cryptography==2.8
distro==1.4.0
hass-nabucasa==0.22
hass-nabucasa==0.23
home-assistant-frontend==20191025.0
importlib-metadata==0.23
jinja2>=2.10.1

View file

@ -616,7 +616,7 @@ habitipy==0.2.0
hangups==0.4.9
# homeassistant.components.cloud
hass-nabucasa==0.22
hass-nabucasa==0.23
# homeassistant.components.mqtt
hbmqtt==0.9.5

View file

@ -225,7 +225,7 @@ ha-ffmpeg==2.0
hangups==0.4.9
# homeassistant.components.cloud
hass-nabucasa==0.22
hass-nabucasa==0.23
# homeassistant.components.mqtt
hbmqtt==0.9.5

View file

@ -0,0 +1,160 @@
"""Test account link services."""
import asyncio
import logging
from time import time
from unittest.mock import Mock, patch
import pytest
from homeassistant import data_entry_flow, config_entries
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.components.cloud import account_link
from homeassistant.util.dt import utcnow
from tests.common import mock_coro, async_fire_time_changed, mock_platform
TEST_DOMAIN = "oauth2_test"
@pytest.fixture
def flow_handler(hass):
"""Return a registered config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
"""Test flow handler."""
DOMAIN = TEST_DOMAIN
@property
def logger(self) -> logging.Logger:
"""Return logger."""
return logging.getLogger(__name__)
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
yield TestFlowHandler
async def test_setup_provide_implementation(hass):
"""Test that we provide implementations."""
account_link.async_setup(hass)
with patch(
"homeassistant.components.cloud.account_link._get_services",
side_effect=lambda _: mock_coro(
[
{"service": "test", "min_version": "0.1.0"},
{"service": "too_new", "min_version": "100.0.0"},
]
),
):
assert (
await config_entry_oauth2_flow.async_get_implementations(
hass, "non_existing"
)
== {}
)
assert (
await config_entry_oauth2_flow.async_get_implementations(hass, "too_new")
== {}
)
implementations = await config_entry_oauth2_flow.async_get_implementations(
hass, "test"
)
assert "cloud" in implementations
assert implementations["cloud"].domain == "cloud"
assert implementations["cloud"].service == "test"
assert implementations["cloud"].hass is hass
async def test_get_services_cached(hass):
"""Test that we cache services."""
hass.data["cloud"] = None
services = 1
with patch.object(account_link, "CACHE_TIMEOUT", 0), patch(
"hass_nabucasa.account_link.async_fetch_available_services",
side_effect=lambda _: mock_coro(services),
) as mock_fetch:
assert await account_link._get_services(hass) == 1
services = 2
assert len(mock_fetch.mock_calls) == 1
assert await account_link._get_services(hass) == 1
services = 3
hass.data.pop(account_link.DATA_SERVICES)
assert await account_link._get_services(hass) == 3
services = 4
async_fire_time_changed(hass, utcnow())
await hass.async_block_till_done()
# Check cache purged
assert await account_link._get_services(hass) == 4
async def test_implementation(hass, flow_handler):
"""Test Cloud OAuth2 implementation."""
hass.data["cloud"] = None
impl = account_link.CloudOAuth2Implementation(hass, "test")
assert impl.name == "Home Assistant Cloud"
assert impl.domain == "cloud"
flow_handler.async_register_implementation(hass, impl)
flow_finished = asyncio.Future()
helper = Mock(
async_get_authorize_url=Mock(return_value=mock_coro("http://example.com/auth")),
async_get_tokens=Mock(return_value=flow_finished),
)
with patch(
"hass_nabucasa.account_link.AuthorizeAccountHelper", return_value=helper
):
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == "http://example.com/auth"
flow_finished.set_result(
{
"refresh_token": "mock-refresh",
"access_token": "mock-access",
"expires_in": 10,
"token_type": "bearer",
}
)
await hass.async_block_till_done()
# Flow finished!
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["data"]["auth_implementation"] == "cloud"
expires_at = result["data"]["token"].pop("expires_at")
assert round(expires_at - time()) == 10
assert result["data"]["token"] == {
"refresh_token": "mock-refresh",
"access_token": "mock-access",
"token_type": "bearer",
"expires_in": 10,
}
entry = hass.config_entries.async_entries(TEST_DOMAIN)[0]
assert (
await config_entry_oauth2_flow.async_get_config_entry_implementation(
hass, entry
)
is impl
)

View file

@ -264,3 +264,45 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
assert config_entry.data["token"]["expires_in"] == 100
assert config_entry.data["token"]["random_other_data"] == "should_stay"
assert round(config_entry.data["token"]["expires_at"] - now) == 100
async def test_implementation_provider(hass, local_impl):
"""Test providing an implementation provider."""
assert (
await config_entry_oauth2_flow.async_get_implementations(hass, TEST_DOMAIN)
== {}
)
mock_domain_with_impl = "some_domain"
config_entry_oauth2_flow.async_register_implementation(
hass, mock_domain_with_impl, local_impl
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl}
provider_source = {}
async def async_provide_implementation(hass, domain):
"""Mock implementation provider."""
return provider_source.get(domain)
config_entry_oauth2_flow.async_add_implementation_provider(
hass, "cloud", async_provide_implementation
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl}
provider_source[
mock_domain_with_impl
] = config_entry_oauth2_flow.LocalOAuth2Implementation(
hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}