diff --git a/homeassistant/components/google/api.py b/homeassistant/components/google/api.py index f4a4912a1b9..47aa32dcd11 100644 --- a/homeassistant/components/google/api.py +++ b/homeassistant/components/google/api.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable import datetime import logging from typing import Any, cast @@ -19,9 +18,12 @@ from oauth2client.client import ( from homeassistant.components.application_credentials import AuthImplementation from homeassistant.config_entries import ConfigEntry -from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_entry_oauth2_flow -from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.event import ( + async_track_point_in_utc_time, + async_track_time_interval, +) from homeassistant.util import dt from .const import ( @@ -76,6 +78,9 @@ class DeviceFlow: self._oauth_flow = oauth_flow self._device_flow_info: DeviceFlowInfo = device_flow_info self._exchange_task_unsub: CALLBACK_TYPE | None = None + self._timeout_unsub: CALLBACK_TYPE | None = None + self._listener: CALLBACK_TYPE | None = None + self._creds: Credentials | None = None @property def verification_url(self) -> str: @@ -87,15 +92,22 @@ class DeviceFlow: """Return the code that the user should enter at the verification url.""" return self._device_flow_info.user_code # type: ignore[no-any-return] - async def start_exchange_task( - self, finished_cb: Callable[[Credentials | None], Awaitable[None]] + @callback + def async_set_listener( + self, + update_callback: CALLBACK_TYPE, ) -> None: - """Start the device auth exchange flow polling. + """Invoke the update callback when the exchange finishes or on timeout.""" + self._listener = update_callback - The callback is invoked with the valid credentials or with None on timeout. - """ + @property + def creds(self) -> Credentials | None: + """Return result of exchange step or None on timeout.""" + return self._creds + + def async_start_exchange(self) -> None: + """Start the device auth exchange flow polling.""" _LOGGER.debug("Starting exchange flow") - assert not self._exchange_task_unsub max_timeout = dt.utcnow() + datetime.timedelta(seconds=EXCHANGE_TIMEOUT_SECONDS) # For some reason, oauth.step1_get_device_and_user_codes() returns a datetime # object without tzinfo. For the comparison below to work, it needs one. @@ -104,31 +116,40 @@ class DeviceFlow: ) expiration_time = min(user_code_expiry, max_timeout) - def _exchange() -> Credentials: - return self._oauth_flow.step2_exchange( - device_flow_info=self._device_flow_info - ) - - async def _poll_attempt(now: datetime.datetime) -> None: - assert self._exchange_task_unsub - _LOGGER.debug("Attempting OAuth code exchange") - # Note: The callback is invoked with None when the device code has expired - creds: Credentials | None = None - if now < expiration_time: - try: - creds = await self._hass.async_add_executor_job(_exchange) - except FlowExchangeError: - _LOGGER.debug("Token not yet ready; trying again later") - return - self._exchange_task_unsub() - self._exchange_task_unsub = None - await finished_cb(creds) - self._exchange_task_unsub = async_track_time_interval( self._hass, - _poll_attempt, + self._async_poll_attempt, datetime.timedelta(seconds=self._device_flow_info.interval), ) + self._timeout_unsub = async_track_point_in_utc_time( + self._hass, self._async_timeout, expiration_time + ) + + async def _async_poll_attempt(self, now: datetime.datetime) -> None: + _LOGGER.debug("Attempting OAuth code exchange") + try: + self._creds = await self._hass.async_add_executor_job(self._exchange) + except FlowExchangeError: + _LOGGER.debug("Token not yet ready; trying again later") + return + self._finish() + + def _exchange(self) -> Credentials: + return self._oauth_flow.step2_exchange(device_flow_info=self._device_flow_info) + + @callback + def _async_timeout(self, now: datetime.datetime) -> None: + _LOGGER.debug("OAuth token exchange timeout") + self._finish() + + @callback + def _finish(self) -> None: + if self._exchange_task_unsub: + self._exchange_task_unsub() + if self._timeout_unsub: + self._timeout_unsub() + if self._listener: + self._listener() def get_feature_access( diff --git a/homeassistant/components/google/config_flow.py b/homeassistant/components/google/config_flow.py index cbe1de69f9e..22b62094e76 100644 --- a/homeassistant/components/google/config_flow.py +++ b/homeassistant/components/google/config_flow.py @@ -7,7 +7,6 @@ from typing import Any from gcal_sync.api import GoogleCalendarService from gcal_sync.exceptions import ApiException -from oauth2client.client import Credentials import voluptuous as vol from homeassistant import config_entries @@ -97,9 +96,9 @@ class OAuth2FlowHandler( return self.async_abort(reason="oauth_error") self._device_flow = device_flow - async def _exchange_finished(creds: Credentials | None) -> None: + def _exchange_finished() -> None: self.external_data = { - DEVICE_AUTH_CREDS: creds + DEVICE_AUTH_CREDS: device_flow.creds } # is None on timeout/expiration self.hass.async_create_task( self.hass.config_entries.flow.async_configure( @@ -107,7 +106,8 @@ class OAuth2FlowHandler( ) ) - await device_flow.start_exchange_task(_exchange_finished) + device_flow.async_set_listener(_exchange_finished) + device_flow.async_start_exchange() return self.async_show_progress( step_id="auth", diff --git a/tests/components/google/test_config_flow.py b/tests/components/google/test_config_flow.py index 00f50e129e4..24ad8a7b769 100644 --- a/tests/components/google/test_config_flow.py +++ b/tests/components/google/test_config_flow.py @@ -10,6 +10,7 @@ from unittest.mock import Mock, patch from aiohttp.client_exceptions import ClientError from freezegun.api import FrozenDateTimeFactory from oauth2client.client import ( + DeviceFlowInfo, FlowExchangeError, OAuth2Credentials, OAuth2DeviceCodeError, @@ -59,10 +60,17 @@ async def mock_code_flow( ) -> YieldFixture[Mock]: """Fixture for initiating OAuth flow.""" with patch( - "oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes", + "homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes", ) as mock_flow: - mock_flow.return_value.user_code_expiry = utcnow() + code_expiration_delta - mock_flow.return_value.interval = CODE_CHECK_INTERVAL + mock_flow.return_value = DeviceFlowInfo.FromResponse( + { + "device_code": "4/4-GMMhmHCXhWEzkobqIHGG_EnNYYsAkukHspeYUk9E8", + "user_code": "GQVQ-JKEC", + "verification_url": "https://www.google.com/device", + "expires_in": code_expiration_delta.total_seconds(), + "interval": CODE_CHECK_INTERVAL, + } + ) yield mock_flow @@ -70,7 +78,8 @@ async def mock_code_flow( async def mock_exchange(creds: OAuth2Credentials) -> YieldFixture[Mock]: """Fixture for mocking out the exchange for credentials.""" with patch( - "oauth2client.client.OAuth2WebServerFlow.step2_exchange", return_value=creds + "homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange", + return_value=creds, ) as mock: yield mock @@ -108,7 +117,6 @@ async def fire_alarm(hass, point_in_time): await hass.async_block_till_done() -@pytest.mark.freeze_time("2022-06-03 15:19:59-00:00") async def test_full_flow_yaml_creds( hass: HomeAssistant, mock_code_flow: Mock, @@ -131,9 +139,8 @@ async def test_full_flow_yaml_creds( "homeassistant.components.google.async_setup_entry", return_value=True ) as mock_setup: # Run one tick to invoke the credential exchange check - freezer.tick(CODE_CHECK_ALARM_TIMEDELTA) - await fire_alarm(hass, datetime.datetime.utcnow()) - await hass.async_block_till_done() + now = utcnow() + await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA) result = await hass.config_entries.flow.async_configure( flow_id=result["flow_id"] ) @@ -143,11 +150,12 @@ async def test_full_flow_yaml_creds( assert "data" in result data = result["data"] assert "token" in data + assert 0 < data["token"]["expires_in"] <= 60 * 60 assert ( - data["token"]["expires_in"] - == 60 * 60 - CODE_CHECK_ALARM_TIMEDELTA.total_seconds() + datetime.datetime.now().timestamp() + <= data["token"]["expires_at"] + < (datetime.datetime.now() + datetime.timedelta(days=8)).timestamp() ) - assert data["token"]["expires_at"] == 1654273199.0 data["token"].pop("expires_at") data["token"].pop("expires_in") assert data == { @@ -238,7 +246,7 @@ async def test_code_error( assert await component_setup() with patch( - "oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes", + "homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes", side_effect=OAuth2DeviceCodeError("Test Failure"), ): result = await hass.config_entries.flow.async_init( @@ -248,13 +256,13 @@ async def test_code_error( assert result.get("reason") == "oauth_error" -@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(minutes=-5)]) +@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(seconds=50)]) async def test_expired_after_exchange( hass: HomeAssistant, mock_code_flow: Mock, component_setup: ComponentSetup, ) -> None: - """Test successful creds setup.""" + """Test credential exchange expires.""" assert await component_setup() result = await hass.config_entries.flow.async_init( @@ -265,10 +273,14 @@ async def test_expired_after_exchange( assert "description_placeholders" in result assert "url" in result["description_placeholders"] - # Run one tick to invoke the credential exchange check - now = utcnow() - await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA) - await hass.async_block_till_done() + # Fail first attempt then advance clock past exchange timeout + with patch( + "homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange", + side_effect=FlowExchangeError(), + ): + now = utcnow() + await fire_alarm(hass, now + datetime.timedelta(seconds=65)) + await hass.async_block_till_done() result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"]) assert result.get("type") == "abort" @@ -295,7 +307,7 @@ async def test_exchange_error( # Run one tick to invoke the credential exchange check now = utcnow() with patch( - "oauth2client.client.OAuth2WebServerFlow.step2_exchange", + "homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange", side_effect=FlowExchangeError(), ): now += CODE_CHECK_ALARM_TIMEDELTA