Keep cloud tokens always valid (#20762)
* Keep auth token always valid * Remove unused refresh_auth message * Capture EndpointConnectionError * Lint
This commit is contained in:
parent
b1faad0a50
commit
2733919cd8
5 changed files with 122 additions and 36 deletions
|
@ -106,6 +106,7 @@ async def async_setup(hass, config):
|
|||
)
|
||||
|
||||
cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
|
||||
await auth_api.async_setup(hass, cloud)
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
|
||||
await http_api.async_setup(hass)
|
||||
return True
|
||||
|
@ -263,7 +264,7 @@ class Cloud:
|
|||
self.access_token = info['access_token']
|
||||
self.refresh_token = info['refresh_token']
|
||||
|
||||
self.hass.add_job(self.iot.connect())
|
||||
self.hass.async_create_task(self.iot.connect())
|
||||
|
||||
def _decode_claims(self, token): # pylint: disable=no-self-use
|
||||
"""Decode the claims in a token."""
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
"""Package to communicate with the authentication API."""
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CloudError(Exception):
|
||||
|
@ -39,6 +45,40 @@ AWS_EXCEPTIONS = {
|
|||
}
|
||||
|
||||
|
||||
async def async_setup(hass, cloud):
|
||||
"""Configure the auth api."""
|
||||
refresh_task = None
|
||||
|
||||
async def handle_token_refresh():
|
||||
"""Handle Cloud access token refresh."""
|
||||
sleep_time = 5
|
||||
sleep_time = random.randint(2400, 3600)
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(sleep_time)
|
||||
await hass.async_add_executor_job(renew_access_token, cloud)
|
||||
except CloudError as err:
|
||||
_LOGGER.error("Can't refresh cloud token: %s", err)
|
||||
except asyncio.CancelledError:
|
||||
# Task is canceled, stop it.
|
||||
break
|
||||
|
||||
sleep_time = random.randint(3100, 3600)
|
||||
|
||||
async def on_connect():
|
||||
"""When the instance is connected."""
|
||||
nonlocal refresh_task
|
||||
refresh_task = hass.async_create_task(handle_token_refresh())
|
||||
|
||||
async def on_disconnect():
|
||||
"""When the instance is disconnected."""
|
||||
nonlocal refresh_task
|
||||
refresh_task.cancel()
|
||||
|
||||
cloud.iot.register_on_connect(on_connect)
|
||||
cloud.iot.register_on_disconnect(on_disconnect)
|
||||
|
||||
|
||||
def _map_aws_exception(err):
|
||||
"""Map AWS exception to our exceptions."""
|
||||
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
|
||||
|
@ -47,7 +87,7 @@ def _map_aws_exception(err):
|
|||
|
||||
def register(cloud, email, password):
|
||||
"""Register a new account."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
|
||||
cognito = _cognito(cloud)
|
||||
# Workaround for bug in Warrant. PR with fix:
|
||||
|
@ -55,13 +95,16 @@ def register(cloud, email, password):
|
|||
cognito.add_base_attributes()
|
||||
try:
|
||||
cognito.register(email, password)
|
||||
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def resend_email_confirm(cloud, email):
|
||||
"""Resend email confirmation."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
|
||||
cognito = _cognito(cloud, username=email)
|
||||
|
||||
|
@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email):
|
|||
)
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def forgot_password(cloud, email):
|
||||
"""Initialize forgotten password flow."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
|
||||
cognito = _cognito(cloud, username=email)
|
||||
|
||||
try:
|
||||
cognito.initiate_forgot_password()
|
||||
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def login(cloud, email, password):
|
||||
|
@ -97,7 +145,7 @@ def login(cloud, email, password):
|
|||
|
||||
def check_token(cloud):
|
||||
"""Check that the token is valid and verify if needed."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
|
||||
cognito = _cognito(
|
||||
cloud,
|
||||
|
@ -109,13 +157,17 @@ def check_token(cloud):
|
|||
cloud.id_token = cognito.id_token
|
||||
cloud.access_token = cognito.access_token
|
||||
cloud.write_user_info()
|
||||
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def renew_access_token(cloud):
|
||||
"""Renew access token."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
|
||||
cognito = _cognito(
|
||||
cloud,
|
||||
|
@ -127,13 +179,17 @@ def renew_access_token(cloud):
|
|||
cloud.id_token = cognito.id_token
|
||||
cloud.access_token = cognito.access_token
|
||||
cloud.write_user_info()
|
||||
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def _authenticate(cloud, email, password):
|
||||
"""Log in and return an authenticated Cognito instance."""
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||
from warrant.exceptions import ForceChangePasswordException
|
||||
|
||||
assert not cloud.is_logged_in, 'Cannot login if already logged in.'
|
||||
|
@ -145,11 +201,14 @@ def _authenticate(cloud, email, password):
|
|||
return cognito
|
||||
|
||||
except ForceChangePasswordException:
|
||||
raise PasswordChangeRequired
|
||||
raise PasswordChangeRequired()
|
||||
|
||||
except ClientError as err:
|
||||
raise _map_aws_exception(err)
|
||||
|
||||
except EndpointConnectionError:
|
||||
raise UnknownError()
|
||||
|
||||
|
||||
def _cognito(cloud, **kwargs):
|
||||
"""Get the client credentials."""
|
||||
|
|
|
@ -62,12 +62,18 @@ class CloudIoT:
|
|||
# Local code waiting for a response
|
||||
self._response_handler = {}
|
||||
self._on_connect = []
|
||||
self._on_disconnect = []
|
||||
|
||||
@callback
|
||||
def register_on_connect(self, on_connect_cb):
|
||||
"""Register an async on_connect callback."""
|
||||
self._on_connect.append(on_connect_cb)
|
||||
|
||||
@callback
|
||||
def register_on_disconnect(self, on_disconnect_cb):
|
||||
"""Register an async on_disconnect callback."""
|
||||
self._on_disconnect.append(on_disconnect_cb)
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
"""Return if we're currently connected."""
|
||||
|
@ -102,6 +108,17 @@ class CloudIoT:
|
|||
# Still adding it here to make sure we can always reconnect
|
||||
_LOGGER.exception("Unexpected error")
|
||||
|
||||
if self.state == STATE_CONNECTED and self._on_disconnect:
|
||||
try:
|
||||
yield from asyncio.wait([
|
||||
cb() for cb in self._on_disconnect
|
||||
])
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Safety net. This should never hit.
|
||||
# Still adding it here to make sure we don't break the flow
|
||||
_LOGGER.exception(
|
||||
"Unexpected error in on_disconnect callbacks")
|
||||
|
||||
if self.close_requested:
|
||||
break
|
||||
|
||||
|
@ -192,7 +209,13 @@ class CloudIoT:
|
|||
self.state = STATE_CONNECTED
|
||||
|
||||
if self._on_connect:
|
||||
try:
|
||||
yield from asyncio.wait([cb() for cb in self._on_connect])
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Safety net. This should never hit.
|
||||
# Still adding it here to make sure we don't break the flow
|
||||
_LOGGER.exception(
|
||||
"Unexpected error in on_connect callbacks")
|
||||
|
||||
while not client.closed:
|
||||
msg = yield from client.receive()
|
||||
|
@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload):
|
|||
await cloud.logout()
|
||||
_LOGGER.error("You have been logged out from Home Assistant cloud: %s",
|
||||
payload['reason'])
|
||||
elif action == 'refresh_auth':
|
||||
# Refresh the auth token between now and payload['seconds']
|
||||
hass.helpers.event.async_call_later(
|
||||
random.randint(0, payload['seconds']),
|
||||
lambda now: auth_api.check_token(cloud))
|
||||
else:
|
||||
_LOGGER.warning("Received unknown cloud action: %s", action)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Tests for the tools to communicate with the cloud."""
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito):
|
|||
assert cloud.id_token != mock_cognito.id_token
|
||||
assert cloud.access_token != mock_cognito.access_token
|
||||
assert len(cloud.write_user_info.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_async_setup(hass):
|
||||
"""Test async setup."""
|
||||
cloud = MagicMock()
|
||||
await auth_api.async_setup(hass, cloud)
|
||||
assert len(cloud.iot.mock_calls) == 2
|
||||
on_connect = cloud.iot.mock_calls[0][1][0]
|
||||
on_disconnect = cloud.iot.mock_calls[1][1][0]
|
||||
|
||||
with patch('random.randint', return_value=0), patch(
|
||||
'homeassistant.components.cloud.auth_api.renew_access_token'
|
||||
) as mock_renew:
|
||||
await on_connect()
|
||||
# Let handle token sleep once
|
||||
await asyncio.sleep(0)
|
||||
# Let handle token refresh token
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(mock_renew.mock_calls) == 1
|
||||
assert mock_renew.mock_calls[0][1][0] is cloud
|
||||
|
||||
await on_disconnect()
|
||||
|
||||
# Make sure task is no longer being called
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert len(mock_renew.mock_calls) == 1
|
||||
|
|
|
@ -10,9 +10,8 @@ from homeassistant.components.cloud import (
|
|||
Cloud, iot, auth_api, MODE_DEV)
|
||||
from homeassistant.components.cloud.const import (
|
||||
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
|
||||
from homeassistant.util import dt as dt_util
|
||||
from tests.components.alexa import test_smart_home as test_alexa
|
||||
from tests.common import mock_coro, async_fire_time_changed
|
||||
from tests.common import mock_coro
|
||||
|
||||
from . import mock_cloud_prefs
|
||||
|
||||
|
@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud):
|
|||
assert len(mock_cloud.logout.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_handling_core_messages_refresh_auth(hass, mock_cloud):
|
||||
"""Test handling core messages."""
|
||||
mock_cloud.hass = hass
|
||||
with patch('random.randint', return_value=0) as mock_rand, patch(
|
||||
'homeassistant.components.cloud.auth_api.check_token'
|
||||
) as mock_check:
|
||||
await iot.async_handle_cloud(hass, mock_cloud, {
|
||||
'action': 'refresh_auth',
|
||||
'seconds': 230,
|
||||
})
|
||||
async_fire_time_changed(hass, dt_util.utcnow())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_rand.mock_calls) == 1
|
||||
assert mock_rand.mock_calls[0][1] == (0, 230)
|
||||
|
||||
assert len(mock_check.mock_calls) == 1
|
||||
assert mock_check.mock_calls[0][1][0] is mock_cloud
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
|
||||
"""Test server disconnecting instance."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue