From 00b5d30e24dccebcc61839be7cf6ca9d87b2a3de Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 30 Apr 2022 08:06:43 -0700 Subject: [PATCH] Add application credentials platform (#69148) * Initial developer credentials scaffolding - Support websocket list/add/delete - Add developer credentials protocol from yaml config - Handle OAuth credential registration and de-registration - Tests for websocket and integration based registration * Fix pydoc text * Remove translations and update owners * Update homeassistant/components/developer_credentials/__init__.py Co-authored-by: Paulus Schoutsen * Update homeassistant/components/developer_credentials/__init__.py Co-authored-by: Paulus Schoutsen * Remove _async_get_developer_credential * Rename to application credentials platform * Fix race condition and add import support * Increase code coverage (92%) * Increase test coverage 93% * Increase test coverage (94%) * Increase test coverage (97%) * Increase test covearge (98%) * Increase test coverage (99%) * Increase test coverage (100%) * Remove http router frozen comment * Remove auth domain override on import * Remove debug statement * Don't import the same client id multiple times * Add auth dependency for local oauth implementation * Revert older oauth2 changes from merge * Update homeassistant/components/application_credentials/__init__.py Co-authored-by: Martin Hjelmare * Move config credential import to its own fixture * Override the mock_application_credentials_integration fixture instead per test * Update application credentials * Add dictionary typing * Use f-strings as per feedback * Add additional structure needed for an MVP application credential Add additional structure needed for an MVP, including a target component Xbox * Add websocket to list supported integrations for frontend selector * Application credentials config * Import xbox credentials * Remove unnecessary async calls * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare * Import credentials with a fixed auth domain Resolve an issue with compatibility of exisiting config entries when importing client credentials Co-authored-by: Paulus Schoutsen Co-authored-by: Martin Hjelmare --- CODEOWNERS | 2 + .../application_credentials/__init__.py | 242 +++++++ .../application_credentials/manifest.json | 9 + .../application_credentials/strings.json | 3 + .../components/cloud/account_link.py | 4 +- .../components/default_config/manifest.json | 1 + homeassistant/components/xbox/__init__.py | 17 +- .../xbox/application_credentials.py | 14 + homeassistant/components/xbox/manifest.json | 2 +- .../generated/application_credentials.py | 10 + .../helpers/config_entry_oauth2_flow.py | 9 +- script/hassfest/__main__.py | 2 + script/hassfest/application_credentials.py | 63 ++ script/hassfest/manifest.py | 1 + .../application_credentials/__init__.py | 1 + .../application_credentials/test_init.py | 623 ++++++++++++++++++ .../helpers/test_config_entry_oauth2_flow.py | 28 +- 17 files changed, 1006 insertions(+), 25 deletions(-) create mode 100644 homeassistant/components/application_credentials/__init__.py create mode 100644 homeassistant/components/application_credentials/manifest.json create mode 100644 homeassistant/components/application_credentials/strings.json create mode 100644 homeassistant/components/xbox/application_credentials.py create mode 100644 homeassistant/generated/application_credentials.py create mode 100644 script/hassfest/application_credentials.py create mode 100644 tests/components/application_credentials/__init__.py create mode 100644 tests/components/application_credentials/test_init.py diff --git a/CODEOWNERS b/CODEOWNERS index c3405001f23..b8d737cdfb7 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -75,6 +75,8 @@ build.json @home-assistant/supervisor /tests/components/api/ @home-assistant/core /homeassistant/components/apple_tv/ @postlund /tests/components/apple_tv/ @postlund +/homeassistant/components/application_credentials/ @home-assistant/core +/tests/components/application_credentials/ @home-assistant/core /homeassistant/components/apprise/ @caronc /tests/components/apprise/ @caronc /homeassistant/components/aprs/ @PhilRW diff --git a/homeassistant/components/application_credentials/__init__.py b/homeassistant/components/application_credentials/__init__.py new file mode 100644 index 00000000000..cc5ed5e44bb --- /dev/null +++ b/homeassistant/components/application_credentials/__init__.py @@ -0,0 +1,242 @@ +"""The Application Credentials integration. + +This integration provides APIs for managing local OAuth credentials on behalf +of other integrations. Integrations register an authorization server, and then +the APIs are used to add one or more client credentials. Integrations may also +provide credentials from yaml for backwards compatibility. +""" +from __future__ import annotations + +from dataclasses import dataclass +import logging +from typing import Any, Protocol + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.connection import ActiveConnection +from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN, CONF_ID +from homeassistant.core import HomeAssistant, callback +from homeassistant.generated.application_credentials import APPLICATION_CREDENTIALS +from homeassistant.helpers import collection, config_entry_oauth2_flow +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.storage import Store +from homeassistant.helpers.typing import ConfigType +from homeassistant.loader import IntegrationNotFound, async_get_integration +from homeassistant.util import slugify + +__all__ = ["ClientCredential", "AuthorizationServer", "async_import_client_credential"] + +_LOGGER = logging.getLogger(__name__) + +DOMAIN = "application_credentials" + +STORAGE_KEY = DOMAIN +STORAGE_VERSION = 1 +DATA_STORAGE = "storage" +CONF_AUTH_DOMAIN = "auth_domain" + +CREATE_FIELDS = { + vol.Required(CONF_DOMAIN): cv.string, + vol.Required(CONF_CLIENT_ID): cv.string, + vol.Required(CONF_CLIENT_SECRET): cv.string, + vol.Optional(CONF_AUTH_DOMAIN): cv.string, +} +UPDATE_FIELDS: dict = {} # Not supported + + +@dataclass +class ClientCredential: + """Represent an OAuth client credential.""" + + client_id: str + client_secret: str + + +@dataclass +class AuthorizationServer: + """Represent an OAuth2 Authorization Server.""" + + authorize_url: str + token_url: str + + +class ApplicationCredentialsStorageCollection(collection.StorageCollection): + """Application credential collection stored in storage.""" + + CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) + + async def _process_create_data(self, data: dict[str, str]) -> dict[str, str]: + """Validate the config is valid.""" + result = self.CREATE_SCHEMA(data) + domain = result[CONF_DOMAIN] + if not await _get_platform(self.hass, domain): + raise ValueError(f"No application_credentials platform for {domain}") + return result + + @callback + def _get_suggested_id(self, info: dict[str, str]) -> str: + """Suggest an ID based on the config.""" + return f"{info[CONF_DOMAIN]}.{info[CONF_CLIENT_ID]}" + + async def _update_data( + self, data: dict[str, str], update_data: dict[str, str] + ) -> dict[str, str]: + """Return a new updated data object.""" + raise ValueError("Updates not supported") + + async def async_delete_item(self, item_id: str) -> None: + """Delete item, verifying credential is not in use.""" + if item_id not in self.data: + raise collection.ItemNotFound(item_id) + + # Cannot delete a credential currently in use by a ConfigEntry + current = self.data[item_id] + entries = self.hass.config_entries.async_entries(current[CONF_DOMAIN]) + for entry in entries: + if entry.data.get("auth_implementation") == item_id: + raise ValueError("Cannot delete credential in use by an integration") + + await super().async_delete_item(item_id) + + async def async_import_item(self, info: dict[str, str]) -> None: + """Import an yaml credential if it does not already exist.""" + suggested_id = self._get_suggested_id(info) + if self.id_manager.has_id(slugify(suggested_id)): + return + await self.async_create_item(info) + + def async_client_credentials(self, domain: str) -> dict[str, ClientCredential]: + """Return ClientCredentials in storage for the specified domain.""" + credentials = {} + for item in self.async_items(): + if item[CONF_DOMAIN] != domain: + continue + auth_domain = ( + item[CONF_AUTH_DOMAIN] if CONF_AUTH_DOMAIN in item else item[CONF_ID] + ) + credentials[auth_domain] = ClientCredential( + item[CONF_CLIENT_ID], item[CONF_CLIENT_SECRET] + ) + return credentials + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up Application Credentials.""" + hass.data[DOMAIN] = {} + + id_manager = collection.IDManager() + storage_collection = ApplicationCredentialsStorageCollection( + Store(hass, STORAGE_VERSION, STORAGE_KEY), + logging.getLogger(f"{__name__}.storage_collection"), + id_manager, + ) + await storage_collection.async_load() + hass.data[DOMAIN][DATA_STORAGE] = storage_collection + + collection.StorageCollectionWebsocket( + storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + ).async_setup(hass) + + websocket_api.async_register_command(hass, handle_integration_list) + + config_entry_oauth2_flow.async_add_implementation_provider( + hass, DOMAIN, _async_provide_implementation + ) + + return True + + +async def async_import_client_credential( + hass: HomeAssistant, domain: str, credential: ClientCredential +) -> None: + """Import an existing credential from configuration.yaml.""" + if DOMAIN not in hass.data: + raise ValueError("Integration 'application_credentials' not setup") + storage_collection = hass.data[DOMAIN][DATA_STORAGE] + item = { + CONF_DOMAIN: domain, + CONF_CLIENT_ID: credential.client_id, + CONF_CLIENT_SECRET: credential.client_secret, + CONF_AUTH_DOMAIN: domain, + } + await storage_collection.async_import_item(item) + + +class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation): + """Application Credentials local oauth2 implementation.""" + + @property + def name(self) -> str: + """Name of the implementation.""" + return self.client_id + + +async def _async_provide_implementation( + hass: HomeAssistant, domain: str +) -> list[config_entry_oauth2_flow.AbstractOAuth2Implementation]: + """Return registered OAuth implementations.""" + + platform = await _get_platform(hass, domain) + if not platform: + return [] + + authorization_server = await platform.async_get_authorization_server(hass) + storage_collection = hass.data[DOMAIN][DATA_STORAGE] + credentials = storage_collection.async_client_credentials(domain) + return [ + AuthImplementation( + hass, + auth_domain, + credential.client_id, + credential.client_secret, + authorization_server.authorize_url, + authorization_server.token_url, + ) + for auth_domain, credential in credentials.items() + ] + + +class ApplicationCredentialsProtocol(Protocol): + """Define the format that application_credentials platforms can have.""" + + async def async_get_authorization_server( + self, hass: HomeAssistant + ) -> AuthorizationServer: + """Return authorization server.""" + + +async def _get_platform( + hass: HomeAssistant, integration_domain: str +) -> ApplicationCredentialsProtocol | None: + """Register an application_credentials platform.""" + try: + integration = await async_get_integration(hass, integration_domain) + except IntegrationNotFound as err: + _LOGGER.debug("Integration '%s' does not exist: %s", integration_domain, err) + return None + try: + platform = integration.get_platform("application_credentials") + except ImportError as err: + _LOGGER.debug( + "Integration '%s' does not provide application_credentials: %s", + integration_domain, + err, + ) + return None + if not hasattr(platform, "async_get_authorization_server"): + raise ValueError( + f"Integration '{integration_domain}' platform application_credentials did not implement 'async_get_authorization_server'" + ) + return platform + + +@websocket_api.websocket_command( + {vol.Required("type"): "application_credentials/config"} +) +@callback +def handle_integration_list( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle integrations command.""" + connection.send_result(msg["id"], {"domains": APPLICATION_CREDENTIALS}) diff --git a/homeassistant/components/application_credentials/manifest.json b/homeassistant/components/application_credentials/manifest.json new file mode 100644 index 00000000000..9a8abc16c36 --- /dev/null +++ b/homeassistant/components/application_credentials/manifest.json @@ -0,0 +1,9 @@ +{ + "domain": "application_credentials", + "name": "Application Credentials", + "config_flow": false, + "documentation": "https://www.home-assistant.io/integrations/application_credentials", + "dependencies": ["auth", "websocket_api"], + "codeowners": ["@home-assistant/core"], + "quality_scale": "internal" +} diff --git a/homeassistant/components/application_credentials/strings.json b/homeassistant/components/application_credentials/strings.json new file mode 100644 index 00000000000..48d74bc75e4 --- /dev/null +++ b/homeassistant/components/application_credentials/strings.json @@ -0,0 +1,3 @@ +{ + "title": "Application Credentials" +} diff --git a/homeassistant/components/cloud/account_link.py b/homeassistant/components/cloud/account_link.py index 6dc0da82512..5df16cb1724 100644 --- a/homeassistant/components/cloud/account_link.py +++ b/homeassistant/components/cloud/account_link.py @@ -34,9 +34,9 @@ async def async_provide_implementation(hass: HomeAssistant, domain: str): for service in services: if service["service"] == domain and CURRENT_VERSION >= service["min_version"]: - return CloudOAuth2Implementation(hass, domain) + return [CloudOAuth2Implementation(hass, domain)] - return + return [] async def _get_services(hass): diff --git a/homeassistant/components/default_config/manifest.json b/homeassistant/components/default_config/manifest.json index 1ab827529c6..1742092cc70 100644 --- a/homeassistant/components/default_config/manifest.json +++ b/homeassistant/components/default_config/manifest.json @@ -3,6 +3,7 @@ "name": "Default Config", "documentation": "https://www.home-assistant.io/integrations/default_config", "dependencies": [ + "application_credentials", "automation", "cloud", "counter", diff --git a/homeassistant/components/xbox/__init__.py b/homeassistant/components/xbox/__init__.py index 0466d0191cf..2b5772dd0ba 100644 --- a/homeassistant/components/xbox/__init__.py +++ b/homeassistant/components/xbox/__init__.py @@ -20,6 +20,7 @@ from xbox.webapi.api.provider.smartglass.models import ( SmartglassConsoleStatus, ) +from homeassistant.components import application_credentials from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, Platform from homeassistant.core import HomeAssistant @@ -31,8 +32,8 @@ from homeassistant.helpers import ( from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.update_coordinator import DataUpdateCoordinator -from . import api, config_flow -from .const import DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN +from . import api +from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -63,15 +64,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if DOMAIN not in config: return True - config_flow.OAuth2FlowHandler.async_register_implementation( + await application_credentials.async_import_client_credential( hass, - config_entry_oauth2_flow.LocalOAuth2Implementation( - hass, - DOMAIN, - config[DOMAIN][CONF_CLIENT_ID], - config[DOMAIN][CONF_CLIENT_SECRET], - OAUTH2_AUTHORIZE, - OAUTH2_TOKEN, + DOMAIN, + application_credentials.ClientCredential( + config[DOMAIN][CONF_CLIENT_ID], config[DOMAIN][CONF_CLIENT_SECRET] ), ) diff --git a/homeassistant/components/xbox/application_credentials.py b/homeassistant/components/xbox/application_credentials.py new file mode 100644 index 00000000000..2e3d7f8a6a0 --- /dev/null +++ b/homeassistant/components/xbox/application_credentials.py @@ -0,0 +1,14 @@ +"""Application credentials platform for xbox.""" + +from homeassistant.components.application_credentials import AuthorizationServer +from homeassistant.core import HomeAssistant + +from .const import OAUTH2_AUTHORIZE, OAUTH2_TOKEN + + +async def async_get_authorization_server(hass: HomeAssistant) -> AuthorizationServer: + """Return authorization server.""" + return AuthorizationServer( + authorize_url=OAUTH2_AUTHORIZE, + token_url=OAUTH2_TOKEN, + ) diff --git a/homeassistant/components/xbox/manifest.json b/homeassistant/components/xbox/manifest.json index 432b3e84100..5adfa54a901 100644 --- a/homeassistant/components/xbox/manifest.json +++ b/homeassistant/components/xbox/manifest.json @@ -4,7 +4,7 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/xbox", "requirements": ["xbox-webapi==2.0.11"], - "dependencies": ["auth"], + "dependencies": ["auth", "application_credentials"], "codeowners": ["@hunterjm"], "iot_class": "cloud_polling" } diff --git a/homeassistant/generated/application_credentials.py b/homeassistant/generated/application_credentials.py new file mode 100644 index 00000000000..ec6c1886e0a --- /dev/null +++ b/homeassistant/generated/application_credentials.py @@ -0,0 +1,10 @@ +"""Automatically generated by hassfest. + +To update, run python3 -m script.hassfest +""" + +# fmt: off + +APPLICATION_CREDENTIALS = [ + "xbox" +] diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index e2b21522d42..d0aaca71304 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -347,10 +347,9 @@ async def async_get_implementations( return registered registered = dict(registered) - - for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items(): - if (implementation := await get_impl(hass, domain)) is not None: - registered[provider_domain] = implementation + for get_impl in list(hass.data[DATA_PROVIDERS].values()): + for impl in await get_impl(hass, domain): + registered[impl.domain] = impl return registered @@ -373,7 +372,7 @@ def async_add_implementation_provider( hass: HomeAssistant, provider_domain: str, async_provide_implementation: Callable[ - [HomeAssistant, str], Awaitable[AbstractOAuth2Implementation | None] + [HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]] ], ) -> None: """Add an implementation provider. diff --git a/script/hassfest/__main__.py b/script/hassfest/__main__.py index c6a9799a502..889cad2a497 100644 --- a/script/hassfest/__main__.py +++ b/script/hassfest/__main__.py @@ -5,6 +5,7 @@ import sys from time import monotonic from . import ( + application_credentials, codeowners, config_flow, coverage, @@ -25,6 +26,7 @@ from . import ( from .model import Config, Integration INTEGRATION_PLUGINS = [ + application_credentials, codeowners, config_flow, dependencies, diff --git a/script/hassfest/application_credentials.py b/script/hassfest/application_credentials.py new file mode 100644 index 00000000000..87a277bb2b8 --- /dev/null +++ b/script/hassfest/application_credentials.py @@ -0,0 +1,63 @@ +"""Generate application_credentials data.""" +from __future__ import annotations + +import json + +from .model import Config, Integration + +BASE = """ +\"\"\"Automatically generated by hassfest. + +To update, run python3 -m script.hassfest +\"\"\" + +# fmt: off + +APPLICATION_CREDENTIALS = {} +""".strip() + + +def generate_and_validate(integrations: dict[str, Integration], config: Config) -> str: + """Validate and generate config flow data.""" + + match_list = [] + + for domain in sorted(integrations): + integration = integrations[domain] + application_credentials_file = integration.path / "application_credentials.py" + if not application_credentials_file.is_file(): + continue + + match_list.append(domain) + + return BASE.format(json.dumps(match_list, indent=4)) + + +def validate(integrations: dict[str, Integration], config: Config) -> None: + """Validate application_credentials data.""" + application_credentials_path = ( + config.root / "homeassistant/generated/application_credentials.py" + ) + config.cache["application_credentials"] = content = generate_and_validate( + integrations, config + ) + + if config.specific_integrations: + return + + if application_credentials_path.read_text(encoding="utf-8").strip() != content: + config.add_error( + "application_credentials", + "File application_credentials.py is not up to date. Run python3 -m script.hassfest", + fixable=True, + ) + + +def generate(integrations: dict[str, Integration], config: Config): + """Generate application_credentials data.""" + application_credentials_path = ( + config.root / "homeassistant/generated/application_credentials.py" + ) + application_credentials_path.write_text( + f"{config.cache['application_credentials']}\n", encoding="utf-8" + ) diff --git a/script/hassfest/manifest.py b/script/hassfest/manifest.py index ca9acedd515..b66b33486cb 100644 --- a/script/hassfest/manifest.py +++ b/script/hassfest/manifest.py @@ -36,6 +36,7 @@ SUPPORTED_IOT_CLASSES = [ NO_IOT_CLASS = [ *{platform.value for platform in Platform}, "api", + "application_credentials", "auth", "automation", "blueprint", diff --git a/tests/components/application_credentials/__init__.py b/tests/components/application_credentials/__init__.py new file mode 100644 index 00000000000..36933b9ccfb --- /dev/null +++ b/tests/components/application_credentials/__init__.py @@ -0,0 +1 @@ +"""Tests for the Application Credentials integration.""" diff --git a/tests/components/application_credentials/test_init.py b/tests/components/application_credentials/test_init.py new file mode 100644 index 00000000000..31cf45f2b54 --- /dev/null +++ b/tests/components/application_credentials/test_init.py @@ -0,0 +1,623 @@ +"""Test the Developer Credentials integration.""" + +from __future__ import annotations + +from collections.abc import Callable, Generator +import logging +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +from aiohttp import ClientWebSocketResponse +import pytest + +from homeassistant import config_entries, data_entry_flow +from homeassistant.components.application_credentials import ( + CONF_AUTH_DOMAIN, + DOMAIN, + AuthorizationServer, + ClientCredential, + async_import_client_credential, +) +from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_entry_oauth2_flow +from homeassistant.setup import async_setup_component + +from tests.common import mock_platform + +CLIENT_ID = "some-client-id" +CLIENT_SECRET = "some-client-secret" +DEVELOPER_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET) +ID = "fake_integration_some_client_id" +AUTHORIZE_URL = "https://example.com/auth" +TOKEN_URL = "https://example.com/oauth2/v4/token" +REFRESH_TOKEN = "mock-refresh-token" +ACCESS_TOKEN = "mock-access-token" + +TEST_DOMAIN = "fake_integration" + + +@pytest.fixture +async def authorization_server() -> AuthorizationServer: + """Fixture AuthorizationServer for mock application_credentials integration.""" + return AuthorizationServer(AUTHORIZE_URL, TOKEN_URL) + + +@pytest.fixture +async def config_credential() -> ClientCredential | None: + """Fixture ClientCredential for mock application_credentials integration.""" + return None + + +@pytest.fixture +async def import_config_credential( + hass: HomeAssistant, config_credential: ClientCredential +) -> None: + """Fixture to import the yaml based credential.""" + await async_import_client_credential(hass, TEST_DOMAIN, config_credential) + + +async def setup_application_credentials_integration( + hass: HomeAssistant, + domain: str, + authorization_server: AuthorizationServer, +) -> None: + """Set up a fake application_credentials integration.""" + hass.config.components.add(domain) + mock_platform( + hass, + f"{domain}.application_credentials", + Mock( + async_get_authorization_server=AsyncMock(return_value=authorization_server), + ), + ) + + +@pytest.fixture(autouse=True) +async def mock_application_credentials_integration( + hass: HomeAssistant, + authorization_server: AuthorizationServer, +): + """Mock a application_credentials integration.""" + assert await async_setup_component(hass, "application_credentials", {}) + await setup_application_credentials_integration( + hass, TEST_DOMAIN, authorization_server + ) + + +class FakeConfigFlow(config_entry_oauth2_flow.AbstractOAuth2FlowHandler, domain=DOMAIN): + """Config flow used during tests.""" + + DOMAIN = TEST_DOMAIN + + @property + def logger(self) -> logging.Logger: + """Return logger.""" + return logging.getLogger(__name__) + + +@pytest.fixture(autouse=True) +def config_flow_handler( + hass: HomeAssistant, current_request_with_host: Any +) -> Generator[FakeConfigFlow, None, None]: + """Fixture for a test config flow.""" + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: FakeConfigFlow}): + yield FakeConfigFlow + + +class OAuthFixture: + """Fixture to facilitate testing an OAuth flow.""" + + def __init__(self, hass, hass_client, aioclient_mock): + """Initialize OAuthFixture.""" + self.hass = hass + self.hass_client = hass_client + self.aioclient_mock = aioclient_mock + self.client_id = CLIENT_ID + + async def complete_external_step( + self, result: data_entry_flow.FlowResult + ) -> data_entry_flow.FlowResult: + """Fixture method to complete the OAuth flow and return the completed result.""" + client = await self.hass_client() + state = config_entry_oauth2_flow._encode_jwt( + self.hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) + assert result["url"] == ( + f"{AUTHORIZE_URL}?response_type=code&client_id={self.client_id}" + "&redirect_uri=https://example.com/auth/external/callback" + f"&state={state}" + ) + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + + self.aioclient_mock.post( + TOKEN_URL, + json={ + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN, + "type": "bearer", + "expires_in": 60, + }, + ) + + result = await self.hass.config_entries.flow.async_configure(result["flow_id"]) + assert result.get("type") == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result.get("title") == self.client_id + assert "data" in result + assert "token" in result["data"] + return result + + +@pytest.fixture +async def oauth_fixture( + hass: HomeAssistant, hass_client_no_auth: Any, aioclient_mock: Any +) -> OAuthFixture: + """Fixture for testing the OAuth flow.""" + return OAuthFixture(hass, hass_client_no_auth, aioclient_mock) + + +class Client: + """Test client with helper methods for application credentials websocket.""" + + def __init__(self, client): + """Initialize Client.""" + self.client = client + self.id = 0 + + async def cmd(self, cmd: str, payload: dict[str, Any] = None) -> dict[str, Any]: + """Send a command and receive the json result.""" + self.id += 1 + await self.client.send_json( + { + "id": self.id, + "type": f"{DOMAIN}/{cmd}", + **(payload if payload is not None else {}), + } + ) + resp = await self.client.receive_json() + assert resp.get("id") == self.id + return resp + + async def cmd_result(self, cmd: str, payload: dict[str, Any] = None) -> Any: + """Send a command and parse the result.""" + resp = await self.cmd(cmd, payload) + assert resp.get("success") + assert resp.get("type") == "result" + return resp.get("result") + + +ClientFixture = Callable[[], Client] + + +@pytest.fixture +async def ws_client( + hass_ws_client: Callable[[...], ClientWebSocketResponse] +) -> ClientFixture: + """Fixture for creating the test websocket client.""" + + async def create_client() -> Client: + ws_client = await hass_ws_client() + return Client(ws_client) + + return create_client + + +async def test_websocket_list_empty(ws_client: ClientFixture): + """Test websocket list command.""" + client = await ws_client() + assert await client.cmd_result("list") == [] + + +async def test_websocket_create(ws_client: ClientFixture): + """Test websocket create command.""" + client = await ws_client() + result = await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert result == { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + } + + result = await client.cmd_result("list") + assert result == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + } + ] + + +async def test_websocket_create_invalid_domain(ws_client: ClientFixture): + """Test websocket create command.""" + client = await ws_client() + resp = await client.cmd( + "create", + { + CONF_DOMAIN: "other-domain", + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "invalid_format" + assert ( + resp["error"].get("message") + == "No application_credentials platform for other-domain" + ) + + +async def test_websocket_update_not_supported(ws_client: ClientFixture): + """Test websocket update command in unsupported.""" + client = await ws_client() + result = await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert result == { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + } + + resp = await client.cmd("update", {"application_credentials_id": ID}) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "invalid_format" + assert resp["error"].get("message") == "Updates not supported" + + +async def test_websocket_delete(ws_client: ClientFixture): + """Test websocket delete command.""" + client = await ws_client() + + await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert await client.cmd_result("list") == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + } + ] + + await client.cmd_result("delete", {"application_credentials_id": ID}) + assert await client.cmd_result("list") == [] + + +async def test_websocket_delete_item_not_found(ws_client: ClientFixture): + """Test websocket delete command.""" + client = await ws_client() + + resp = await client.cmd("delete", {"application_credentials_id": ID}) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "not_found" + assert ( + resp["error"].get("message") + == f"Unable to find application_credentials_id {ID}" + ) + + +@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL]) +async def test_websocket_import_config( + ws_client: ClientFixture, + config_credential: ClientCredential, + import_config_credential: Any, +): + """Test websocket list command for an imported credential.""" + client = await ws_client() + + # Imported creds returned from websocket + assert await client.cmd_result("list") == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + CONF_AUTH_DOMAIN: TEST_DOMAIN, + } + ] + + # Imported credential can be deleted + await client.cmd_result("delete", {"application_credentials_id": ID}) + assert await client.cmd_result("list") == [] + + +@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL]) +async def test_import_duplicate_credentials( + hass: HomeAssistant, + ws_client: ClientFixture, + config_credential: ClientCredential, + import_config_credential: Any, +): + """Exercise duplicate credentials are ignored.""" + + # Import the test credential again and verify it is not imported twice + await async_import_client_credential(hass, TEST_DOMAIN, DEVELOPER_CREDENTIAL) + client = await ws_client() + assert await client.cmd_result("list") == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + CONF_AUTH_DOMAIN: TEST_DOMAIN, + } + ] + + +async def test_config_flow_no_credentials(hass): + """Test config flow base case with no credentials registered.""" + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT + assert result.get("reason") == "missing_configuration" + + +async def test_config_flow_other_domain( + hass: HomeAssistant, + ws_client: ClientFixture, + authorization_server: AuthorizationServer, +): + """Test config flow ignores credentials for another domain.""" + await setup_application_credentials_integration( + hass, + "other_domain", + authorization_server, + ) + client = await ws_client() + await client.cmd_result( + "create", + { + CONF_DOMAIN: "other_domain", + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT + assert result.get("reason") == "missing_configuration" + + +async def test_config_flow( + hass: HomeAssistant, + ws_client: ClientFixture, + oauth_fixture: OAuthFixture, +): + """Test config flow with application credential registered.""" + client = await ws_client() + + await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + result = await oauth_fixture.complete_external_step(result) + assert ( + result["data"].get("auth_implementation") == "fake_integration_some_client_id" + ) + + # Verify it is not possible to delete an in-use config entry + resp = await client.cmd("delete", {"application_credentials_id": ID}) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "unknown_error" + + +async def test_config_flow_multiple_entries( + hass: HomeAssistant, + ws_client: ClientFixture, + oauth_fixture: OAuthFixture, +): + """Test config flow with multiple application credentials registered.""" + client = await ws_client() + + await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID + "2", + CONF_CLIENT_SECRET: CLIENT_SECRET + "2", + }, + ) + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_FORM + assert result.get("step_id") == "pick_implementation" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={"implementation": "fake_integration_some_client_id2"}, + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + oauth_fixture.client_id = CLIENT_ID + "2" + result = await oauth_fixture.complete_external_step(result) + assert ( + result["data"].get("auth_implementation") == "fake_integration_some_client_id2" + ) + + +async def test_config_flow_create_delete_credential( + hass: HomeAssistant, + ws_client: ClientFixture, + oauth_fixture: OAuthFixture, +): + """Test adding and deleting a credential unregisters from the config flow.""" + client = await ws_client() + + await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + await client.cmd("delete", {"application_credentials_id": ID}) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT + assert result.get("reason") == "missing_configuration" + + +@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL]) +async def test_config_flow_with_config_credential( + hass, + hass_client_no_auth, + aioclient_mock, + oauth_fixture, + config_credential, + import_config_credential, +): + """Test config flow with application credential registered.""" + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + result = await oauth_fixture.complete_external_step(result) + # Uses the imported auth domain for compatibility + assert result["data"].get("auth_implementation") == TEST_DOMAIN + + +@pytest.mark.parametrize("mock_application_credentials_integration", [None]) +async def test_import_without_setup(hass, config_credential): + """Test import of credentials without setting up the integration.""" + + with pytest.raises(ValueError): + await async_import_client_credential(hass, TEST_DOMAIN, config_credential) + + # Config flow does not have authentication + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT + assert result.get("reason") == "missing_configuration" + + +@pytest.mark.parametrize("mock_application_credentials_integration", [None]) +async def test_websocket_without_platform( + hass: HomeAssistant, ws_client: ClientFixture +): + """Test an integration without the application credential platform.""" + assert await async_setup_component(hass, "application_credentials", {}) + hass.config.components.add(TEST_DOMAIN) + + client = await ws_client() + resp = await client.cmd( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "invalid_format" + + # Config flow does not have authentication + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT + assert result.get("reason") == "missing_configuration" + + +@pytest.mark.parametrize("mock_application_credentials_integration", [None]) +async def test_websocket_without_authorization_server( + hass: HomeAssistant, ws_client: ClientFixture +): + """Test platform with incorrect implementation.""" + assert await async_setup_component(hass, "application_credentials", {}) + hass.config.components.add(TEST_DOMAIN) + + # Platform does not implemenent async_get_authorization_server + platform = Mock() + del platform.async_get_authorization_server + mock_platform( + hass, + f"{TEST_DOMAIN}.application_credentials", + platform, + ) + + client = await ws_client() + resp = await client.cmd( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + ) + assert not resp.get("success") + assert "error" in resp + assert resp["error"].get("code") == "invalid_format" + + # Config flow does not have authentication + with pytest.raises(ValueError): + await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + +async def test_websocket_integration_list(ws_client: ClientFixture): + """Test websocket integration list command.""" + client = await ws_client() + with patch( + "homeassistant.components.application_credentials.APPLICATION_CREDENTIALS", + ["example1", "example2"], + ): + assert await client.cmd_result("config") == { + "domains": ["example1", "example2"] + } diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 2cd3184b44b..97e728d022d 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -537,11 +537,11 @@ async def test_implementation_provider(hass, local_impl): hass, mock_domain_with_impl ) == {TEST_DOMAIN: local_impl} - provider_source = {} + provider_source = [] async def async_provide_implementation(hass, domain): """Mock implementation provider.""" - return provider_source.get(domain) + return provider_source config_entry_oauth2_flow.async_add_implementation_provider( hass, "cloud", async_provide_implementation @@ -551,15 +551,29 @@ async def test_implementation_provider(hass, local_impl): hass, mock_domain_with_impl ) == {TEST_DOMAIN: local_impl} - provider_source[ - mock_domain_with_impl - ] = config_entry_oauth2_flow.LocalOAuth2Implementation( - hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL + provider_source.append( + config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL + ) ) assert await config_entry_oauth2_flow.async_get_implementations( hass, mock_domain_with_impl - ) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]} + ) == {TEST_DOMAIN: local_impl, "cloud": provider_source[0]} + + provider_source.append( + config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, "other", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL + ) + ) + + assert await config_entry_oauth2_flow.async_get_implementations( + hass, mock_domain_with_impl + ) == { + TEST_DOMAIN: local_impl, + "cloud": provider_source[0], + "other": provider_source[1], + } async def test_oauth_session_refresh_failure(