Keep cloud tokens always valid (#20762)

* Keep auth token always valid

* Remove unused refresh_auth message

* Capture EndpointConnectionError

* Lint
This commit is contained in:
Paulus Schoutsen 2019-02-05 01:45:03 -08:00 committed by Pascal Vizeli
parent b1faad0a50
commit 2733919cd8
5 changed files with 122 additions and 36 deletions

View file

@ -106,6 +106,7 @@ async def async_setup(hass, config):
) )
cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs) cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
await auth_api.async_setup(hass, cloud)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
await http_api.async_setup(hass) await http_api.async_setup(hass)
return True return True
@ -263,7 +264,7 @@ class Cloud:
self.access_token = info['access_token'] self.access_token = info['access_token']
self.refresh_token = info['refresh_token'] self.refresh_token = info['refresh_token']
self.hass.add_job(self.iot.connect()) self.hass.async_create_task(self.iot.connect())
def _decode_claims(self, token): # pylint: disable=no-self-use def _decode_claims(self, token): # pylint: disable=no-self-use
"""Decode the claims in a token.""" """Decode the claims in a token."""

View file

@ -1,4 +1,10 @@
"""Package to communicate with the authentication API.""" """Package to communicate with the authentication API."""
import asyncio
import logging
import random
_LOGGER = logging.getLogger(__name__)
class CloudError(Exception): class CloudError(Exception):
@ -39,6 +45,40 @@ AWS_EXCEPTIONS = {
} }
async def async_setup(hass, cloud):
"""Configure the auth api."""
refresh_task = None
async def handle_token_refresh():
"""Handle Cloud access token refresh."""
sleep_time = 5
sleep_time = random.randint(2400, 3600)
while True:
try:
await asyncio.sleep(sleep_time)
await hass.async_add_executor_job(renew_access_token, cloud)
except CloudError as err:
_LOGGER.error("Can't refresh cloud token: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break
sleep_time = random.randint(3100, 3600)
async def on_connect():
"""When the instance is connected."""
nonlocal refresh_task
refresh_task = hass.async_create_task(handle_token_refresh())
async def on_disconnect():
"""When the instance is disconnected."""
nonlocal refresh_task
refresh_task.cancel()
cloud.iot.register_on_connect(on_connect)
cloud.iot.register_on_disconnect(on_disconnect)
def _map_aws_exception(err): def _map_aws_exception(err):
"""Map AWS exception to our exceptions.""" """Map AWS exception to our exceptions."""
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError) ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
@ -47,7 +87,7 @@ def _map_aws_exception(err):
def register(cloud, email, password): def register(cloud, email, password):
"""Register a new account.""" """Register a new account."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud) cognito = _cognito(cloud)
# Workaround for bug in Warrant. PR with fix: # Workaround for bug in Warrant. PR with fix:
@ -55,13 +95,16 @@ def register(cloud, email, password):
cognito.add_base_attributes() cognito.add_base_attributes()
try: try:
cognito.register(email, password) cognito.register(email, password)
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def resend_email_confirm(cloud, email): def resend_email_confirm(cloud, email):
"""Resend email confirmation.""" """Resend email confirmation."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud, username=email) cognito = _cognito(cloud, username=email)
@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email):
) )
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def forgot_password(cloud, email): def forgot_password(cloud, email):
"""Initialize forgotten password flow.""" """Initialize forgotten password flow."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud, username=email) cognito = _cognito(cloud, username=email)
try: try:
cognito.initiate_forgot_password() cognito.initiate_forgot_password()
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def login(cloud, email, password): def login(cloud, email, password):
@ -97,7 +145,7 @@ def login(cloud, email, password):
def check_token(cloud): def check_token(cloud):
"""Check that the token is valid and verify if needed.""" """Check that the token is valid and verify if needed."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito( cognito = _cognito(
cloud, cloud,
@ -109,13 +157,17 @@ def check_token(cloud):
cloud.id_token = cognito.id_token cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token cloud.access_token = cognito.access_token
cloud.write_user_info() cloud.write_user_info()
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def renew_access_token(cloud): def renew_access_token(cloud):
"""Renew access token.""" """Renew access token."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito( cognito = _cognito(
cloud, cloud,
@ -127,13 +179,17 @@ def renew_access_token(cloud):
cloud.id_token = cognito.id_token cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token cloud.access_token = cognito.access_token
cloud.write_user_info() cloud.write_user_info()
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def _authenticate(cloud, email, password): def _authenticate(cloud, email, password):
"""Log in and return an authenticated Cognito instance.""" """Log in and return an authenticated Cognito instance."""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, EndpointConnectionError
from warrant.exceptions import ForceChangePasswordException from warrant.exceptions import ForceChangePasswordException
assert not cloud.is_logged_in, 'Cannot login if already logged in.' assert not cloud.is_logged_in, 'Cannot login if already logged in.'
@ -145,11 +201,14 @@ def _authenticate(cloud, email, password):
return cognito return cognito
except ForceChangePasswordException: except ForceChangePasswordException:
raise PasswordChangeRequired raise PasswordChangeRequired()
except ClientError as err: except ClientError as err:
raise _map_aws_exception(err) raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def _cognito(cloud, **kwargs): def _cognito(cloud, **kwargs):
"""Get the client credentials.""" """Get the client credentials."""

View file

@ -62,12 +62,18 @@ class CloudIoT:
# Local code waiting for a response # Local code waiting for a response
self._response_handler = {} self._response_handler = {}
self._on_connect = [] self._on_connect = []
self._on_disconnect = []
@callback @callback
def register_on_connect(self, on_connect_cb): def register_on_connect(self, on_connect_cb):
"""Register an async on_connect callback.""" """Register an async on_connect callback."""
self._on_connect.append(on_connect_cb) self._on_connect.append(on_connect_cb)
@callback
def register_on_disconnect(self, on_disconnect_cb):
"""Register an async on_disconnect callback."""
self._on_disconnect.append(on_disconnect_cb)
@property @property
def connected(self): def connected(self):
"""Return if we're currently connected.""" """Return if we're currently connected."""
@ -102,6 +108,17 @@ class CloudIoT:
# Still adding it here to make sure we can always reconnect # Still adding it here to make sure we can always reconnect
_LOGGER.exception("Unexpected error") _LOGGER.exception("Unexpected error")
if self.state == STATE_CONNECTED and self._on_disconnect:
try:
yield from asyncio.wait([
cb() for cb in self._on_disconnect
])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_disconnect callbacks")
if self.close_requested: if self.close_requested:
break break
@ -192,7 +209,13 @@ class CloudIoT:
self.state = STATE_CONNECTED self.state = STATE_CONNECTED
if self._on_connect: if self._on_connect:
yield from asyncio.wait([cb() for cb in self._on_connect]) try:
yield from asyncio.wait([cb() for cb in self._on_connect])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_connect callbacks")
while not client.closed: while not client.closed:
msg = yield from client.receive() msg = yield from client.receive()
@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload):
await cloud.logout() await cloud.logout()
_LOGGER.error("You have been logged out from Home Assistant cloud: %s", _LOGGER.error("You have been logged out from Home Assistant cloud: %s",
payload['reason']) payload['reason'])
elif action == 'refresh_auth':
# Refresh the auth token between now and payload['seconds']
hass.helpers.event.async_call_later(
random.randint(0, payload['seconds']),
lambda now: auth_api.check_token(cloud))
else: else:
_LOGGER.warning("Received unknown cloud action: %s", action) _LOGGER.warning("Received unknown cloud action: %s", action)

View file

@ -1,4 +1,5 @@
"""Tests for the tools to communicate with the cloud.""" """Tests for the tools to communicate with the cloud."""
import asyncio
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito):
assert cloud.id_token != mock_cognito.id_token assert cloud.id_token != mock_cognito.id_token
assert cloud.access_token != mock_cognito.access_token assert cloud.access_token != mock_cognito.access_token
assert len(cloud.write_user_info.mock_calls) == 0 assert len(cloud.write_user_info.mock_calls) == 0
async def test_async_setup(hass):
"""Test async setup."""
cloud = MagicMock()
await auth_api.async_setup(hass, cloud)
assert len(cloud.iot.mock_calls) == 2
on_connect = cloud.iot.mock_calls[0][1][0]
on_disconnect = cloud.iot.mock_calls[1][1][0]
with patch('random.randint', return_value=0), patch(
'homeassistant.components.cloud.auth_api.renew_access_token'
) as mock_renew:
await on_connect()
# Let handle token sleep once
await asyncio.sleep(0)
# Let handle token refresh token
await asyncio.sleep(0)
assert len(mock_renew.mock_calls) == 1
assert mock_renew.mock_calls[0][1][0] is cloud
await on_disconnect()
# Make sure task is no longer being called
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(mock_renew.mock_calls) == 1

View file

@ -10,9 +10,8 @@ from homeassistant.components.cloud import (
Cloud, iot, auth_api, MODE_DEV) Cloud, iot, auth_api, MODE_DEV)
from homeassistant.components.cloud.const import ( from homeassistant.components.cloud.const import (
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE) PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
from homeassistant.util import dt as dt_util
from tests.components.alexa import test_smart_home as test_alexa from tests.components.alexa import test_smart_home as test_alexa
from tests.common import mock_coro, async_fire_time_changed from tests.common import mock_coro
from . import mock_cloud_prefs from . import mock_cloud_prefs
@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud):
assert len(mock_cloud.logout.mock_calls) == 1 assert len(mock_cloud.logout.mock_calls) == 1
async def test_handling_core_messages_refresh_auth(hass, mock_cloud):
"""Test handling core messages."""
mock_cloud.hass = hass
with patch('random.randint', return_value=0) as mock_rand, patch(
'homeassistant.components.cloud.auth_api.check_token'
) as mock_check:
await iot.async_handle_cloud(hass, mock_cloud, {
'action': 'refresh_auth',
'seconds': 230,
})
async_fire_time_changed(hass, dt_util.utcnow())
await hass.async_block_till_done()
assert len(mock_rand.mock_calls) == 1
assert mock_rand.mock_calls[0][1] == (0, 230)
assert len(mock_check.mock_calls) == 1
assert mock_check.mock_calls[0][1][0] is mock_cloud
@asyncio.coroutine @asyncio.coroutine
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud): def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
"""Test server disconnecting instance.""" """Test server disconnecting instance."""