Update Nest configuration flow to handle upcoming changes to Pub/Sub provisioning (#128909)

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
Allen Porter 2024-10-29 04:58:36 -07:00 committed by GitHub
parent f0bff09b5e
commit 8e7ffd9e16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 669 additions and 132 deletions

View file

@ -59,6 +59,7 @@ from .const import (
CONF_PROJECT_ID,
CONF_SUBSCRIBER_ID,
CONF_SUBSCRIBER_ID_IMPORTED,
CONF_SUBSCRIPTION_NAME,
DATA_DEVICE_MANAGER,
DATA_SDM,
DATA_SUBSCRIBER,
@ -289,7 +290,9 @@ async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle removal of pubsub subscriptions created during config flow."""
if (
DATA_SDM not in entry.data
or CONF_SUBSCRIBER_ID not in entry.data
or not (
CONF_SUBSCRIPTION_NAME in entry.data or CONF_SUBSCRIBER_ID in entry.data
)
or CONF_SUBSCRIBER_ID_IMPORTED in entry.data
):
return

View file

@ -8,6 +8,7 @@ from typing import cast
from aiohttp import ClientSession
from google.oauth2.credentials import Credentials
from google_nest_sdm.admin_client import PUBSUB_API_HOST, AdminClient
from google_nest_sdm.auth import AbstractAuth
from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber
@ -19,6 +20,7 @@ from .const import (
API_URL,
CONF_PROJECT_ID,
CONF_SUBSCRIBER_ID,
CONF_SUBSCRIPTION_NAME,
OAUTH2_TOKEN,
SDM_SCOPES,
)
@ -80,9 +82,10 @@ class AccessTokenAuthImpl(AbstractAuth):
self,
websession: ClientSession,
access_token: str,
host: str,
) -> None:
"""Init the Nest client library auth implementation."""
super().__init__(websession, API_URL)
super().__init__(websession, host)
self._access_token = access_token
async def async_get_access_token(self) -> str:
@ -111,29 +114,47 @@ async def new_subscriber(
implementation, config_entry_oauth2_flow.LocalOAuth2Implementation
):
raise TypeError(f"Unexpected auth implementation {implementation}")
if not (subscriber_id := entry.data.get(CONF_SUBSCRIBER_ID)):
raise ValueError("Configuration option 'subscriber_id' missing")
subscription_name = entry.data.get(
CONF_SUBSCRIPTION_NAME, entry.data[CONF_SUBSCRIBER_ID]
)
auth = AsyncConfigEntryAuth(
aiohttp_client.async_get_clientsession(hass),
config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation),
implementation.client_id,
implementation.client_secret,
)
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscriber_id)
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name)
def new_subscriber_with_token(
hass: HomeAssistant,
access_token: str,
project_id: str,
subscriber_id: str,
subscription_name: str,
) -> GoogleNestSubscriber:
"""Create a GoogleNestSubscriber with an access token."""
return GoogleNestSubscriber(
AccessTokenAuthImpl(
aiohttp_client.async_get_clientsession(hass),
access_token,
API_URL,
),
project_id,
subscriber_id,
subscription_name,
)
def new_pubsub_admin_client(
hass: HomeAssistant,
access_token: str,
cloud_project_id: str,
) -> AdminClient:
"""Create a Nest AdminClient with an access token."""
return AdminClient(
auth=AccessTokenAuthImpl(
aiohttp_client.async_get_clientsession(hass),
access_token,
PUBSUB_API_HOST,
),
cloud_project_id=cloud_project_id,
)

View file

@ -12,14 +12,14 @@ from __future__ import annotations
from collections.abc import Iterable, Mapping
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from google_nest_sdm.exceptions import (
ApiException,
AuthException,
ConfigurationException,
SubscriberException,
from google_nest_sdm.admin_client import (
AdminClient,
EligibleSubscriptions,
EligibleTopics,
)
from google_nest_sdm.exceptions import ApiException
from google_nest_sdm.structure import Structure
import voluptuous as vol
@ -31,8 +31,9 @@ from . import api
from .const import (
CONF_CLOUD_PROJECT_ID,
CONF_PROJECT_ID,
CONF_SUBSCRIBER_ID,
DATA_NEST_CONFIG,
CONF_SUBSCRIBER_ID_IMPORTED,
CONF_SUBSCRIPTION_NAME,
CONF_TOPIC_NAME,
DATA_SDM,
DOMAIN,
OAUTH2_AUTHORIZE,
@ -58,7 +59,7 @@ DEVICE_ACCESS_CONSOLE_URL = "https://console.nest.google.com/device-access/"
DEVICE_ACCESS_CONSOLE_EDIT_URL = (
"https://console.nest.google.com/device-access/project/{project_id}/information"
)
CREATE_NEW_SUBSCRIPTION_KEY = "create_new_subscription"
_LOGGER = logging.getLogger(__name__)
@ -95,6 +96,9 @@ class NestFlowHandler(
self._data: dict[str, Any] = {DATA_SDM: {}}
# Possible name to use for config entry based on the Google Home name
self._structure_config_title: str | None = None
self._admin_client: AdminClient | None = None
self._eligible_topics: EligibleTopics | None = None
self._eligible_subscriptions: EligibleSubscriptions | None = None
@property
def logger(self) -> logging.Logger:
@ -113,8 +117,7 @@ class NestFlowHandler(
async def async_generate_authorize_url(self) -> str:
"""Generate a url for the user to authorize based on user input."""
config = self.hass.data.get(DOMAIN, {}).get(DATA_NEST_CONFIG, {})
project_id = self._data.get(CONF_PROJECT_ID, config.get(CONF_PROJECT_ID, ""))
project_id = self._data.get(CONF_PROJECT_ID)
query = await super().async_generate_authorize_url()
authorize_url = OAUTH2_AUTHORIZE.format(project_id=project_id)
return f"{authorize_url}{query}"
@ -123,6 +126,7 @@ class NestFlowHandler(
"""Complete OAuth setup and finish pubsub or finish."""
_LOGGER.debug("Finishing post-oauth configuration")
self._data.update(data)
_LOGGER.debug("self.source=%s", self.source)
if self.source == SOURCE_REAUTH:
_LOGGER.debug("Skipping Pub/Sub configuration")
return await self._async_finish()
@ -132,6 +136,7 @@ class NestFlowHandler(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Perform reauth upon an API authentication error."""
_LOGGER.debug("async_step_reauth %s", self.source)
self._data.update(entry_data)
return await self.async_step_reauth_confirm()
@ -238,40 +243,114 @@ class NestFlowHandler(
async def async_step_pubsub(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Configure and create Pub/Sub subscriber."""
"""Configure and the pre-requisites to configure Pub/Sub topics and subscriptions."""
data = {
**self._data,
**(user_input if user_input is not None else {}),
}
cloud_project_id = data.get(CONF_CLOUD_PROJECT_ID, "").strip()
config = self.hass.data.get(DOMAIN, {}).get(DATA_NEST_CONFIG, {})
project_id = data.get(CONF_PROJECT_ID, config.get(CONF_PROJECT_ID))
device_access_project_id = data[CONF_PROJECT_ID]
errors: dict[str, str] = {}
if cloud_project_id:
# Create the subscriber id and/or verify it already exists. Note that
# the existing id is used, and create call below is idempotent
if not (subscriber_id := data.get(CONF_SUBSCRIBER_ID, "")):
subscriber_id = _generate_subscription_id(cloud_project_id)
_LOGGER.debug("Creating subscriber id '%s'", subscriber_id)
subscriber = api.new_subscriber_with_token(
self.hass,
self._data["token"]["access_token"],
project_id,
subscriber_id,
access_token = self._data["token"]["access_token"]
self._admin_client = api.new_pubsub_admin_client(
self.hass, access_token=access_token, cloud_project_id=cloud_project_id
)
try:
await subscriber.create_subscription()
except AuthException as err:
_LOGGER.error("Subscriber authentication error: %s", err)
return self.async_abort(reason="invalid_access_token")
except ConfigurationException as err:
_LOGGER.error("Configuration error creating subscription: %s", err)
errors[CONF_CLOUD_PROJECT_ID] = "bad_project_id"
except SubscriberException as err:
_LOGGER.error("Error creating subscription: %s", err)
errors[CONF_CLOUD_PROJECT_ID] = "subscriber_error"
eligible_topics = await self._admin_client.list_eligible_topics(
device_access_project_id=device_access_project_id
)
except ApiException as err:
_LOGGER.error("Error listing eligible Pub/Sub topics: %s", err)
errors["base"] = "pubsub_api_error"
else:
if not eligible_topics.topic_names:
errors["base"] = "no_pubsub_topics"
if not errors:
self._data[CONF_CLOUD_PROJECT_ID] = cloud_project_id
self._eligible_topics = eligible_topics
return await self.async_step_pubsub_topic()
return self.async_show_form(
step_id="pubsub",
data_schema=vol.Schema(
{
vol.Required(CONF_CLOUD_PROJECT_ID, default=cloud_project_id): str,
}
),
description_placeholders={
"url": CLOUD_CONSOLE_URL,
"device_access_console_url": DEVICE_ACCESS_CONSOLE_URL,
"more_info_url": MORE_INFO_URL,
},
errors=errors,
)
async def async_step_pubsub_topic(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Configure and create Pub/Sub topic."""
if TYPE_CHECKING:
assert self._eligible_topics
if user_input is not None:
self._data.update(user_input)
return await self.async_step_pubsub_subscription()
topics = list(self._eligible_topics.topic_names)
return self.async_show_form(
step_id="pubsub_topic",
data_schema=vol.Schema(
{
vol.Optional(CONF_TOPIC_NAME, default=topics[0]): vol.In(topics),
}
),
description_placeholders={
"device_access_console_url": DEVICE_ACCESS_CONSOLE_URL,
"more_info_url": MORE_INFO_URL,
},
)
async def async_step_pubsub_subscription(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Configure and create Pub/Sub subscription."""
if TYPE_CHECKING:
assert self._admin_client
errors = {}
if user_input is not None:
subscription_name = user_input[CONF_SUBSCRIPTION_NAME]
if subscription_name == CREATE_NEW_SUBSCRIPTION_KEY:
topic_name = self._data[CONF_TOPIC_NAME]
subscription_name = _generate_subscription_id(
self._data[CONF_CLOUD_PROJECT_ID]
)
_LOGGER.debug(
"Creating subscription %s on topic %s",
subscription_name,
topic_name,
)
try:
await self._admin_client.create_subscription(
topic_name,
subscription_name,
)
except ApiException as err:
_LOGGER.error("Error creatingPub/Sub subscription: %s", err)
errors["base"] = "pubsub_api_error"
else:
user_input[CONF_SUBSCRIPTION_NAME] = subscription_name
else:
# The user created this subscription themselves so do not delete when removing the integration.
user_input[CONF_SUBSCRIBER_ID_IMPORTED] = True
if not errors:
self._data.update(user_input)
subscriber = api.new_subscriber_with_token(
self.hass,
self._data["token"]["access_token"],
self._data[CONF_PROJECT_ID],
subscription_name,
)
try:
device_manager = await subscriber.async_get_device_manager()
except ApiException as err:
@ -281,23 +360,39 @@ class NestFlowHandler(
self._structure_config_title = generate_config_title(
device_manager.structures.values()
)
self._data.update(
{
CONF_SUBSCRIBER_ID: subscriber_id,
CONF_CLOUD_PROJECT_ID: cloud_project_id,
}
)
return await self._async_finish()
subscriptions = {}
try:
eligible_subscriptions = (
await self._admin_client.list_eligible_subscriptions(
expected_topic_name=self._data[CONF_TOPIC_NAME],
)
)
except ApiException as err:
_LOGGER.error(
"Error talking to API to list eligible Pub/Sub subscriptions: %s", err
)
errors["base"] = "pubsub_api_error"
else:
subscriptions.update(
{name: name for name in eligible_subscriptions.subscription_names}
)
subscriptions[CREATE_NEW_SUBSCRIPTION_KEY] = "Create New"
return self.async_show_form(
step_id="pubsub",
step_id="pubsub_subscription",
data_schema=vol.Schema(
{
vol.Required(CONF_CLOUD_PROJECT_ID, default=cloud_project_id): str,
vol.Optional(
CONF_SUBSCRIPTION_NAME,
default=next(iter(subscriptions)),
): vol.In(subscriptions),
}
),
description_placeholders={"url": CLOUD_CONSOLE_URL},
description_placeholders={
"topic": self._data[CONF_TOPIC_NAME],
"more_info_url": MORE_INFO_URL,
},
errors=errors,
)

View file

@ -4,13 +4,14 @@ DOMAIN = "nest"
DATA_SDM = "sdm"
DATA_SUBSCRIBER = "subscriber"
DATA_DEVICE_MANAGER = "device_manager"
DATA_NEST_CONFIG = "nest_config"
WEB_AUTH_DOMAIN = DOMAIN
INSTALLED_AUTH_DOMAIN = f"{DOMAIN}.installed"
CONF_PROJECT_ID = "project_id"
CONF_SUBSCRIBER_ID = "subscriber_id"
CONF_TOPIC_NAME = "topic_name"
CONF_SUBSCRIPTION_NAME = "subscription_name"
CONF_SUBSCRIBER_ID = "subscriber_id" # Old format
CONF_SUBSCRIBER_ID_IMPORTED = "subscriber_id_imported"
CONF_CLOUD_PROJECT_ID = "cloud_project_id"

View file

@ -26,12 +26,26 @@
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
},
"pubsub": {
"title": "Configure Google Cloud",
"description": "Visit the [Cloud Console]({url}) to find your Google Cloud Project ID.",
"title": "Configure Google Cloud Pub/Sub",
"description": "Home Assistant uses Cloud Pub/Sub receive realtime Nest device updates. Nest servers publish updates to a Pub/Sub topic and Home Assistat receives the updates through a Pub/Sub subscription.\n\n1. Visit the [Device Access Console]({device_access_console_url}) and ensure a Pub/Sub topic is configured.\n2. Visit the [Cloud Console]({url}) to find your Google Cloud Project ID and confirm it is correct below.\n3. The next step will attempt to audo-discover Pub/Sub topics and subscriptions.\n\nSee the integration documentation for [more info]({more_info_url}).",
"data": {
"cloud_project_id": "[%key:component::nest::config::step::cloud_project::data::cloud_project_id%]"
}
},
"pubsub_topic": {
"title": "Configure Cloud Pub/Sub topic",
"description": "Nest devices publish updates on a Cloud Pub/Sub topic. Select the Pub/Sub topic below that is the same as the [Device Access Console]({device_access_console_url}). See the integration documentation for [more info]({more_info_url}).",
"data": {
"topic_name": "Pub/Sub topic Name"
}
},
"pubsub_subscription": {
"title": "Configure Cloud Pub/Sub subscription",
"description": "Home Assistant receives realtime Nest device updates with a Cloud Pub/Sub subscription for topic `{topic}`.\n\nSelect an existing subscription below if one already exists, or the next step will create a new one for you. See the integration documentation for [more info]({more_info_url}).",
"data": {
"subscription_name": "Pub/Sub subscription Name"
}
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "The Nest integration needs to re-authenticate your account"
@ -40,7 +54,9 @@
"error": {
"bad_project_id": "Please enter a valid Cloud Project ID (check Cloud Console)",
"wrong_project_id": "Please enter a valid Cloud Project ID (was same as Device Access Project ID)",
"subscriber_error": "Unknown subscriber error, see logs"
"subscriber_error": "Unknown subscriber error, see logs",
"no_pubsub_topics": "No eligible Pub/Sub topics found, please ensure Device Access Console has a Pub/Sub topic.",
"pubsub_api_error": "Unknown error talking to Cloud Pub/Sub, see logs"
},
"abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_account%]",

View file

@ -6,11 +6,7 @@ from http import HTTPStatus
from typing import Any
from unittest.mock import patch
from google_nest_sdm.exceptions import (
AuthException,
ConfigurationException,
SubscriberException,
)
from google_nest_sdm.exceptions import AuthException
from google_nest_sdm.structure import Structure
import pytest
@ -40,7 +36,7 @@ from tests.typing import ClientSessionGenerator
WEB_REDIRECT_URL = "https://example.com/auth/external/callback"
APP_REDIRECT_URL = "urn:ietf:wg:oauth:2.0:oob"
RAND_SUBSCRIBER_SUFFIX = "ABCDEF"
FAKE_DHCP_DATA = dhcp.DhcpServiceInfo(
ip="127.0.0.2", macaddress="001122334455", hostname="fake_hostname"
@ -53,6 +49,16 @@ def nest_test_config() -> NestTestConfig:
return TEST_CONFIGFLOW_APP_CREDS
@pytest.fixture(autouse=True)
def mock_rand_topic_name_fixture() -> None:
"""Set the topic name random string to a constant."""
with patch(
"homeassistant.components.nest.config_flow.get_random_string",
return_value=RAND_SUBSCRIBER_SUFFIX,
):
yield
class OAuthFixture:
"""Simulate the oauth flow used by the config flow."""
@ -158,6 +164,43 @@ class OAuthFixture:
},
)
async def async_complete_pubsub_flow(
self,
result: dict,
selected_topic: str,
selected_subscription: str = "create_new_subscription",
user_input: dict | None = None,
) -> ConfigEntry:
"""Fixture to walk through the Pub/Sub topic and subscription steps.
This picks a simple set of steps that are reusable for most flows without
exercising the corner cases.
"""
# Validate Pub/Sub topics are shown
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_topic"
assert not result.get("errors")
# Select Pub/Sub topic the show available subscriptions (none)
result = await self.async_configure(
result,
{
"topic_name": selected_topic,
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert not result.get("errors")
# Create the subscription and end the flow
return await self.async_finish_setup(
result,
{
"subscription_name": selected_subscription,
},
)
async def async_finish_setup(
self, result: dict, user_input: dict | None = None
) -> ConfigEntry:
@ -179,15 +222,6 @@ class OAuthFixture:
user_input,
)
async def async_pubsub_flow(self, result: dict, cloud_project_id="") -> None:
"""Verify the pubsub creation step."""
# Render form with a link to get an auth token
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "pubsub"
assert "description_placeholders" in result
assert "url" in result["description_placeholders"]
assert result["data_schema"]({}) == {"cloud_project_id": cloud_project_id}
def get_config_entry(self) -> ConfigEntry:
"""Get the config entry."""
entries = self.hass.config_entries.async_entries(DOMAIN)
@ -206,6 +240,115 @@ async def oauth(
return OAuthFixture(hass, hass_client_no_auth, aioclient_mock)
@pytest.fixture(name="sdm_managed_topic")
def mock_sdm_managed_topic() -> bool:
"""Fixture to configure fake server responses for SDM owend Pub/Sub topics."""
return False
@pytest.fixture(name="user_managed_topics")
def mock_user_managed_topics() -> list[str]:
"""Fixture to configure fake server response for user owned Pub/Sub topics."""
return []
@pytest.fixture(name="subscriptions")
def mock_subscriptions() -> list[tuple[str, str]]:
"""Fixture to configure fake server response for user subscriptions that exist."""
return []
@pytest.fixture(name="device_access_project_id")
def mock_device_access_project_id() -> str:
"""Fixture to configure the device access console project id used in tests."""
return PROJECT_ID
@pytest.fixture(name="cloud_project_id")
def mock_cloud_project_id() -> str:
"""Fixture to configure the cloud console project id used in tests."""
return CLOUD_PROJECT_ID
@pytest.fixture(name="create_subscription_status")
def mock_create_subscription_status() -> str:
"""Fixture to configure the return code when creating the subscription."""
return HTTPStatus.OK
@pytest.fixture(name="list_topics_status")
def mock_list_topics_status() -> str:
"""Fixture to configure the return code when listing topics."""
return HTTPStatus.OK
@pytest.fixture(name="list_subscriptions_status")
def mock_list_subscriptions_status() -> str:
"""Fixture to configure the return code when listing subscriptions."""
return HTTPStatus.OK
@pytest.fixture(autouse=True)
def mock_pubsub_api_responses(
aioclient_mock: AiohttpClientMocker,
sdm_managed_topic: bool,
user_managed_topics: list[str],
subscriptions: list[tuple[str, str]],
device_access_project_id: str,
cloud_project_id: str,
create_subscription_status: HTTPStatus,
list_topics_status: HTTPStatus,
list_subscriptions_status: HTTPStatus,
) -> None:
"""Configure a server response for an SDM managed Pub/Sub topic.
We check for a topic created by the SDM Device Access Console (but note we don't have permission to read it)
or the user has created one themselves in the Google Cloud Project.
"""
aioclient_mock.get(
f"https://pubsub.googleapis.com/v1/projects/sdm-prod/topics/enterprise-{device_access_project_id}",
status=HTTPStatus.FORBIDDEN if sdm_managed_topic else HTTPStatus.NOT_FOUND,
)
aioclient_mock.get(
f"https://pubsub.googleapis.com/v1/projects/{cloud_project_id}/topics",
json={
"topics": [
{
"name": topic_name,
}
for topic_name in user_managed_topics or ()
]
},
status=list_topics_status,
)
# We check for a topic created by the SDM Device Access Console (but note we don't have permission to read it)
# or the user has created one themselves in the Google Cloud Project.
aioclient_mock.get(
f"https://pubsub.googleapis.com/v1/projects/{cloud_project_id}/subscriptions",
json={
"subscriptions": [
{
"name": subscription_name,
"topic": topic,
"pushConfig": {},
"ackDeadlineSeconds": 10,
"messageRetentionDuration": "604800s",
"expirationPolicy": {"ttl": "2678400s"},
"state": "ACTIVE",
}
for (subscription_name, topic) in subscriptions or ()
]
},
status=list_subscriptions_status,
)
aioclient_mock.put(
f"https://pubsub.googleapis.com/v1/projects/{cloud_project_id}/subscriptions/home-assistant-{RAND_SUBSCRIBER_SUFFIX}",
json={},
status=create_subscription_status,
)
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_app_credentials(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
@ -218,20 +361,22 @@ async def test_app_credentials(
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result)
result = await oauth.async_configure(result, None)
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
data = dict(entry.data)
assert "token" in data
data["token"].pop("expires_in")
data["token"].pop("expires_at")
assert "subscriber_id" in data
assert f"projects/{CLOUD_PROJECT_ID}/subscriptions" in data["subscriber_id"]
data.pop("subscriber_id")
assert data == {
"sdm": {},
"auth_implementation": "imported-cred",
"cloud_project_id": CLOUD_PROJECT_ID,
"project_id": PROJECT_ID,
"subscription_name": f"projects/{CLOUD_PROJECT_ID}/subscriptions/home-assistant-{RAND_SUBSCRIBER_SUFFIX}",
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
@ -240,6 +385,10 @@ async def test_app_credentials(
}
@pytest.mark.parametrize(
("sdm_managed_topic", "device_access_project_id", "cloud_project_id"),
[(True, "new-project-id", "new-cloud-project-id")],
)
async def test_config_flow_restart(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
@ -272,20 +421,22 @@ async def test_config_flow_restart(
await oauth.async_oauth_web_flow(result, "new-project-id")
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic="projects/sdm-prod/topics/enterprise-new-project-id"
)
data = dict(entry.data)
assert "token" in data
data["token"].pop("expires_in")
data["token"].pop("expires_at")
assert "subscriber_id" in data
assert "projects/new-cloud-project-id/subscriptions" in data["subscriber_id"]
data.pop("subscriber_id")
assert data == {
"sdm": {},
"auth_implementation": "imported-cred",
"cloud_project_id": "new-cloud-project-id",
"project_id": "new-project-id",
"subscription_name": "projects/new-cloud-project-id/subscriptions/home-assistant-ABCDEF",
"topic_name": "projects/sdm-prod/topics/enterprise-new-project-id",
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
@ -294,6 +445,7 @@ async def test_config_flow_restart(
}
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_config_flow_wrong_project_id(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
@ -324,20 +476,22 @@ async def test_config_flow_wrong_project_id(
await hass.async_block_till_done()
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic="projects/sdm-prod/topics/enterprise-some-project-id"
)
data = dict(entry.data)
assert "token" in data
data["token"].pop("expires_in")
data["token"].pop("expires_at")
assert "subscriber_id" in data
assert f"projects/{CLOUD_PROJECT_ID}/subscriptions" in data["subscriber_id"]
data.pop("subscriber_id")
assert data == {
"sdm": {},
"auth_implementation": "imported-cred",
"cloud_project_id": CLOUD_PROJECT_ID,
"project_id": PROJECT_ID,
"subscription_name": "projects/cloud-id-9876/subscriptions/home-assistant-ABCDEF",
"topic_name": "projects/sdm-prod/topics/enterprise-some-project-id",
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
@ -346,6 +500,9 @@ async def test_config_flow_wrong_project_id(
}
@pytest.mark.parametrize(
("sdm_managed_topic", "create_subscription_status"), [(True, HTTPStatus.NOT_FOUND)]
)
async def test_config_flow_pubsub_configuration_error(
hass: HomeAssistant,
oauth,
@ -361,14 +518,41 @@ async def test_config_flow_pubsub_configuration_error(
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
mock_subscriber.create_subscription.side_effect = ConfigurationException
result = await oauth.async_configure(result, {"code": "1234"})
assert result["type"] is FlowResultType.FORM
assert "errors" in result
assert "cloud_project_id" in result["errors"]
assert result["errors"]["cloud_project_id"] == "bad_project_id"
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_topic"
assert result.get("data_schema")({}) == {
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
}
# Select Pub/Sub topic the show available subscriptions (none)
result = await oauth.async_configure(
result,
{
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert result.get("data_schema")({}) == {
"subscription_name": "create_new_subscription",
}
# Failure when creating the subscription
result = await oauth.async_configure(
result,
{
"subscription_name": "create_new_subscription",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("errors") == {"base": "pubsub_api_error"}
@pytest.mark.parametrize(
("sdm_managed_topic", "create_subscription_status"),
[(True, HTTPStatus.INTERNAL_SERVER_ERROR)],
)
async def test_config_flow_pubsub_subscriber_error(
hass: HomeAssistant, oauth, setup_platform, mock_subscriber
) -> None:
@ -380,17 +564,42 @@ async def test_config_flow_pubsub_subscriber_error(
)
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
mock_subscriber.create_subscription.side_effect = SubscriberException()
result = await oauth.async_configure(result, {"code": "1234"})
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_topic"
assert result.get("data_schema")({}) == {
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
}
assert result["type"] is FlowResultType.FORM
assert "errors" in result
assert "cloud_project_id" in result["errors"]
assert result["errors"]["cloud_project_id"] == "subscriber_error"
# Select Pub/Sub topic the show available subscriptions (none)
result = await oauth.async_configure(
result,
{
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert result.get("data_schema")({}) == {
"subscription_name": "create_new_subscription",
}
# Failure when creating the subscription
result = await oauth.async_configure(
result,
{
"subscription_name": "create_new_subscription",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("errors") == {"base": "pubsub_api_error"}
@pytest.mark.parametrize("nest_test_config", [TEST_CONFIG_APP_CREDS])
@pytest.mark.parametrize(
("nest_test_config", "sdm_managed_topic", "device_access_project_id"),
[(TEST_CONFIG_APP_CREDS, True, "project-id-2")],
)
async def test_multiple_config_entries(
hass: HomeAssistant, oauth, setup_platform
) -> None:
@ -405,7 +614,10 @@ async def test_multiple_config_entries(
)
await oauth.async_app_creds_flow(result, project_id="project-id-2")
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result)
result = await oauth.async_configure(result, user_input={})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic="projects/sdm-prod/topics/enterprise-project-id-2"
)
assert entry.title == "Mock Title"
assert "token" in entry.data
@ -413,7 +625,9 @@ async def test_multiple_config_entries(
assert len(entries) == 2
@pytest.mark.parametrize("nest_test_config", [TEST_CONFIG_APP_CREDS])
@pytest.mark.parametrize(
("nest_test_config", "sdm_managed_topic"), [(TEST_CONFIG_APP_CREDS, True)]
)
async def test_duplicate_config_entries(
hass: HomeAssistant, oauth, setup_platform
) -> None:
@ -438,7 +652,9 @@ async def test_duplicate_config_entries(
assert result.get("reason") == "already_configured"
@pytest.mark.parametrize("nest_test_config", [TEST_CONFIG_APP_CREDS])
@pytest.mark.parametrize(
("nest_test_config", "sdm_managed_topic"), [(TEST_CONFIG_APP_CREDS, True)]
)
async def test_reauth_multiple_config_entries(
hass: HomeAssistant, oauth, setup_platform, config_entry
) -> None:
@ -489,6 +705,7 @@ async def test_reauth_multiple_config_entries(
assert entry.data.get("extra_data")
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_pubsub_subscription_strip_whitespace(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
@ -502,8 +719,10 @@ async def test_pubsub_subscription_strip_whitespace(
result, cloud_project_id=" " + CLOUD_PROJECT_ID + " "
)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic="projects/sdm-prod/topics/enterprise-some-project-id"
)
assert entry.title == "Import from configuration.yaml"
assert "token" in entry.data
entry.data["token"].pop("expires_at")
@ -514,10 +733,14 @@ async def test_pubsub_subscription_strip_whitespace(
"type": "Bearer",
"expires_in": 60,
}
assert "subscriber_id" in entry.data
assert "subscription_name" in entry.data
assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID
@pytest.mark.parametrize(
("sdm_managed_topic", "create_subscription_status"),
[(True, HTTPStatus.UNAUTHORIZED)],
)
async def test_pubsub_subscription_auth_failure(
hass: HomeAssistant, oauth, setup_platform, mock_subscriber
) -> None:
@ -528,17 +751,43 @@ async def test_pubsub_subscription_auth_failure(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_subscriber.create_subscription.side_effect = AuthException()
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
result = await oauth.async_configure(result, {"code": "1234"})
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_topic"
assert result.get("data_schema")({}) == {
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
}
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "invalid_access_token"
# Select Pub/Sub topic the show available subscriptions (none)
result = await oauth.async_configure(
result,
{
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert result.get("data_schema")({}) == {
"subscription_name": "create_new_subscription",
}
# Failure when creating the subscription
result = await oauth.async_configure(
result,
{
"subscription_name": "create_new_subscription",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert result.get("errors") == {"base": "pubsub_api_error"}
@pytest.mark.parametrize("nest_test_config", [TEST_CONFIG_APP_CREDS])
@pytest.mark.parametrize(
("nest_test_config", "sdm_managed_topic"), [(TEST_CONFIG_APP_CREDS, True)]
)
async def test_pubsub_subscriber_config_entry_reauth(
hass: HomeAssistant,
oauth,
@ -568,6 +817,7 @@ async def test_pubsub_subscriber_config_entry_reauth(
assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_config_entry_title_from_home(
hass: HomeAssistant, oauth, setup_platform, subscriber
) -> None:
@ -595,13 +845,24 @@ async def test_config_entry_title_from_home(
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
assert entry.title == "Example Home"
assert "token" in entry.data
assert "subscriber_id" in entry.data
assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID
assert entry.data.get("cloud_project_id") == CLOUD_PROJECT_ID
assert (
entry.data.get("subscription_name")
== f"projects/{CLOUD_PROJECT_ID}/subscriptions/home-assistant-{RAND_SUBSCRIBER_SUFFIX}"
)
assert (
entry.data.get("topic_name")
== f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_config_entry_title_multiple_homes(
hass: HomeAssistant, oauth, setup_platform, subscriber
) -> None:
@ -641,10 +902,14 @@ async def test_config_entry_title_multiple_homes(
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
assert entry.title == "Example Home #1, Example Home #2"
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_title_failure_fallback(
hass: HomeAssistant, oauth, setup_platform, mock_subscriber
) -> None:
@ -658,13 +923,26 @@ async def test_title_failure_fallback(
oauth.async_mock_refresh()
mock_subscriber.async_get_device_manager.side_effect = AuthException()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
assert entry.title == "Import from configuration.yaml"
assert "token" in entry.data
assert "subscriber_id" in entry.data
assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID
assert entry.data.get("cloud_project_id") == CLOUD_PROJECT_ID
assert (
entry.data.get("subscription_name")
== f"projects/{CLOUD_PROJECT_ID}/subscriptions/home-assistant-{RAND_SUBSCRIBER_SUFFIX}"
)
assert (
entry.data.get("topic_name")
== f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_structure_missing_trait(
hass: HomeAssistant, oauth, setup_platform, subscriber
) -> None:
@ -689,7 +967,10 @@ async def test_structure_missing_trait(
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
# Fallback to default name
assert entry.title == "Import from configuration.yaml"
@ -713,6 +994,7 @@ async def test_dhcp_discovery(
assert result.get("reason") == "missing_credentials"
@pytest.mark.parametrize(("sdm_managed_topic"), [(True)])
async def test_dhcp_discovery_with_creds(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
@ -735,21 +1017,23 @@ async def test_dhcp_discovery_with_creds(
result = await oauth.async_configure(result, {"project_id": PROJECT_ID})
await oauth.async_oauth_web_flow(result)
oauth.async_mock_refresh()
entry = await oauth.async_finish_setup(result, {"code": "1234"})
await hass.async_block_till_done()
result = await oauth.async_configure(result, {"code": "1234"})
entry = await oauth.async_complete_pubsub_flow(
result, selected_topic=f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}"
)
data = dict(entry.data)
assert "token" in data
data["token"].pop("expires_in")
data["token"].pop("expires_at")
assert "subscriber_id" in data
assert f"projects/{CLOUD_PROJECT_ID}/subscriptions" in data["subscriber_id"]
data.pop("subscriber_id")
assert data == {
"sdm": {},
"auth_implementation": "imported-cred",
"cloud_project_id": CLOUD_PROJECT_ID,
"project_id": PROJECT_ID,
"subscription_name": f"projects/{CLOUD_PROJECT_ID}/subscriptions/home-assistant-{RAND_SUBSCRIBER_SUFFIX}",
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
@ -789,3 +1073,133 @@ async def test_token_error(
result = await oauth.async_configure(result, user_input=None)
assert result.get("type") is FlowResultType.ABORT
assert result.get("reason") == error_reason
@pytest.mark.parametrize(
("user_managed_topics", "subscriptions"),
[
(
[f"projects/{CLOUD_PROJECT_ID}/topics/some-topic-id"],
[
(
f"projects/{CLOUD_PROJECT_ID}/subscriptions/some-subscription-id",
f"projects/{CLOUD_PROJECT_ID}/topics/some-topic-id",
)
],
)
],
)
async def test_existing_topic_and_subscription(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
"""Test selecting existing user managed topic and subscription."""
await setup_platform()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
result = await oauth.async_configure(result, None)
entry = await oauth.async_complete_pubsub_flow(
result,
selected_topic=f"projects/{CLOUD_PROJECT_ID}/topics/some-topic-id",
selected_subscription=f"projects/{CLOUD_PROJECT_ID}/subscriptions/some-subscription-id",
)
data = dict(entry.data)
assert "token" in data
data["token"].pop("expires_in")
data["token"].pop("expires_at")
assert data == {
"sdm": {},
"auth_implementation": "imported-cred",
"cloud_project_id": CLOUD_PROJECT_ID,
"project_id": PROJECT_ID,
"subscription_name": f"projects/{CLOUD_PROJECT_ID}/subscriptions/some-subscription-id",
"subscriber_id_imported": True,
"topic_name": f"projects/{CLOUD_PROJECT_ID}/topics/some-topic-id",
"token": {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
"type": "Bearer",
},
}
async def test_no_eligible_topics(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
"""Test the case where there are no eligible pub/sub topics."""
await setup_platform()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
result = await oauth.async_configure(result, None)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub"
assert result.get("errors") == {"base": "no_pubsub_topics"}
@pytest.mark.parametrize(
("list_topics_status"),
[
(HTTPStatus.INTERNAL_SERVER_ERROR),
],
)
async def test_list_topics_failure(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
"""Test selecting existing user managed topic and subscription."""
await setup_platform()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
result = await oauth.async_configure(result, None)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub"
assert result.get("errors") == {"base": "pubsub_api_error"}
@pytest.mark.parametrize(
("sdm_managed_topic", "list_subscriptions_status"),
[
(True, HTTPStatus.INTERNAL_SERVER_ERROR),
],
)
async def test_list_subscriptions_failure(
hass: HomeAssistant, oauth, subscriber, setup_platform
) -> None:
"""Test selecting existing user managed topic and subscription."""
await setup_platform()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await oauth.async_app_creds_flow(result)
oauth.async_mock_refresh()
result = await oauth.async_configure(result, None)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_topic"
assert not result.get("errors")
# Select Pub/Sub topic the show available subscriptions (none)
result = await oauth.async_configure(
result,
{
"topic_name": f"projects/sdm-prod/topics/enterprise-{PROJECT_ID}",
},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "pubsub_subscription"
assert result.get("errors") == {"base": "pubsub_api_error"}

View file

@ -171,19 +171,6 @@ async def test_subscriber_auth_failure(
assert flows[0]["step_id"] == "reauth_confirm"
@pytest.mark.parametrize("subscriber_id", [(None)])
async def test_setup_missing_subscriber_id(
hass: HomeAssistant, warning_caplog: pytest.LogCaptureFixture, setup_base_platform
) -> None:
"""Test missing subscriber id from configuration."""
await setup_base_platform()
assert "Configuration option" in warning_caplog.text
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
assert entries[0].state is ConfigEntryState.SETUP_ERROR
@pytest.mark.parametrize("subscriber_side_effect", [(ConfigurationException())])
async def test_subscriber_configuration_failure(
hass: HomeAssistant,