Keep cloud tokens always valid (#20762)
* Keep auth token always valid * Remove unused refresh_auth message * Capture EndpointConnectionError * Lint
This commit is contained in:
parent
b1faad0a50
commit
2733919cd8
5 changed files with 122 additions and 36 deletions
|
@ -106,6 +106,7 @@ async def async_setup(hass, config):
|
||||||
)
|
)
|
||||||
|
|
||||||
cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
|
cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
|
||||||
|
await auth_api.async_setup(hass, cloud)
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
|
||||||
await http_api.async_setup(hass)
|
await http_api.async_setup(hass)
|
||||||
return True
|
return True
|
||||||
|
@ -263,7 +264,7 @@ class Cloud:
|
||||||
self.access_token = info['access_token']
|
self.access_token = info['access_token']
|
||||||
self.refresh_token = info['refresh_token']
|
self.refresh_token = info['refresh_token']
|
||||||
|
|
||||||
self.hass.add_job(self.iot.connect())
|
self.hass.async_create_task(self.iot.connect())
|
||||||
|
|
||||||
def _decode_claims(self, token): # pylint: disable=no-self-use
|
def _decode_claims(self, token): # pylint: disable=no-self-use
|
||||||
"""Decode the claims in a token."""
|
"""Decode the claims in a token."""
|
||||||
|
|
|
@ -1,4 +1,10 @@
|
||||||
"""Package to communicate with the authentication API."""
|
"""Package to communicate with the authentication API."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CloudError(Exception):
|
class CloudError(Exception):
|
||||||
|
@ -39,6 +45,40 @@ AWS_EXCEPTIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass, cloud):
|
||||||
|
"""Configure the auth api."""
|
||||||
|
refresh_task = None
|
||||||
|
|
||||||
|
async def handle_token_refresh():
|
||||||
|
"""Handle Cloud access token refresh."""
|
||||||
|
sleep_time = 5
|
||||||
|
sleep_time = random.randint(2400, 3600)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(sleep_time)
|
||||||
|
await hass.async_add_executor_job(renew_access_token, cloud)
|
||||||
|
except CloudError as err:
|
||||||
|
_LOGGER.error("Can't refresh cloud token: %s", err)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Task is canceled, stop it.
|
||||||
|
break
|
||||||
|
|
||||||
|
sleep_time = random.randint(3100, 3600)
|
||||||
|
|
||||||
|
async def on_connect():
|
||||||
|
"""When the instance is connected."""
|
||||||
|
nonlocal refresh_task
|
||||||
|
refresh_task = hass.async_create_task(handle_token_refresh())
|
||||||
|
|
||||||
|
async def on_disconnect():
|
||||||
|
"""When the instance is disconnected."""
|
||||||
|
nonlocal refresh_task
|
||||||
|
refresh_task.cancel()
|
||||||
|
|
||||||
|
cloud.iot.register_on_connect(on_connect)
|
||||||
|
cloud.iot.register_on_disconnect(on_disconnect)
|
||||||
|
|
||||||
|
|
||||||
def _map_aws_exception(err):
|
def _map_aws_exception(err):
|
||||||
"""Map AWS exception to our exceptions."""
|
"""Map AWS exception to our exceptions."""
|
||||||
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
|
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
|
||||||
|
@ -47,7 +87,7 @@ def _map_aws_exception(err):
|
||||||
|
|
||||||
def register(cloud, email, password):
|
def register(cloud, email, password):
|
||||||
"""Register a new account."""
|
"""Register a new account."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
|
|
||||||
cognito = _cognito(cloud)
|
cognito = _cognito(cloud)
|
||||||
# Workaround for bug in Warrant. PR with fix:
|
# Workaround for bug in Warrant. PR with fix:
|
||||||
|
@ -55,13 +95,16 @@ def register(cloud, email, password):
|
||||||
cognito.add_base_attributes()
|
cognito.add_base_attributes()
|
||||||
try:
|
try:
|
||||||
cognito.register(email, password)
|
cognito.register(email, password)
|
||||||
|
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def resend_email_confirm(cloud, email):
|
def resend_email_confirm(cloud, email):
|
||||||
"""Resend email confirmation."""
|
"""Resend email confirmation."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
|
|
||||||
cognito = _cognito(cloud, username=email)
|
cognito = _cognito(cloud, username=email)
|
||||||
|
|
||||||
|
@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email):
|
||||||
)
|
)
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def forgot_password(cloud, email):
|
def forgot_password(cloud, email):
|
||||||
"""Initialize forgotten password flow."""
|
"""Initialize forgotten password flow."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
|
|
||||||
cognito = _cognito(cloud, username=email)
|
cognito = _cognito(cloud, username=email)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cognito.initiate_forgot_password()
|
cognito.initiate_forgot_password()
|
||||||
|
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def login(cloud, email, password):
|
def login(cloud, email, password):
|
||||||
|
@ -97,7 +145,7 @@ def login(cloud, email, password):
|
||||||
|
|
||||||
def check_token(cloud):
|
def check_token(cloud):
|
||||||
"""Check that the token is valid and verify if needed."""
|
"""Check that the token is valid and verify if needed."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
|
|
||||||
cognito = _cognito(
|
cognito = _cognito(
|
||||||
cloud,
|
cloud,
|
||||||
|
@ -109,13 +157,17 @@ def check_token(cloud):
|
||||||
cloud.id_token = cognito.id_token
|
cloud.id_token = cognito.id_token
|
||||||
cloud.access_token = cognito.access_token
|
cloud.access_token = cognito.access_token
|
||||||
cloud.write_user_info()
|
cloud.write_user_info()
|
||||||
|
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def renew_access_token(cloud):
|
def renew_access_token(cloud):
|
||||||
"""Renew access token."""
|
"""Renew access token."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
|
|
||||||
cognito = _cognito(
|
cognito = _cognito(
|
||||||
cloud,
|
cloud,
|
||||||
|
@ -127,13 +179,17 @@ def renew_access_token(cloud):
|
||||||
cloud.id_token = cognito.id_token
|
cloud.id_token = cognito.id_token
|
||||||
cloud.access_token = cognito.access_token
|
cloud.access_token = cognito.access_token
|
||||||
cloud.write_user_info()
|
cloud.write_user_info()
|
||||||
|
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(cloud, email, password):
|
def _authenticate(cloud, email, password):
|
||||||
"""Log in and return an authenticated Cognito instance."""
|
"""Log in and return an authenticated Cognito instance."""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError, EndpointConnectionError
|
||||||
from warrant.exceptions import ForceChangePasswordException
|
from warrant.exceptions import ForceChangePasswordException
|
||||||
|
|
||||||
assert not cloud.is_logged_in, 'Cannot login if already logged in.'
|
assert not cloud.is_logged_in, 'Cannot login if already logged in.'
|
||||||
|
@ -145,11 +201,14 @@ def _authenticate(cloud, email, password):
|
||||||
return cognito
|
return cognito
|
||||||
|
|
||||||
except ForceChangePasswordException:
|
except ForceChangePasswordException:
|
||||||
raise PasswordChangeRequired
|
raise PasswordChangeRequired()
|
||||||
|
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise _map_aws_exception(err)
|
raise _map_aws_exception(err)
|
||||||
|
|
||||||
|
except EndpointConnectionError:
|
||||||
|
raise UnknownError()
|
||||||
|
|
||||||
|
|
||||||
def _cognito(cloud, **kwargs):
|
def _cognito(cloud, **kwargs):
|
||||||
"""Get the client credentials."""
|
"""Get the client credentials."""
|
||||||
|
|
|
@ -62,12 +62,18 @@ class CloudIoT:
|
||||||
# Local code waiting for a response
|
# Local code waiting for a response
|
||||||
self._response_handler = {}
|
self._response_handler = {}
|
||||||
self._on_connect = []
|
self._on_connect = []
|
||||||
|
self._on_disconnect = []
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def register_on_connect(self, on_connect_cb):
|
def register_on_connect(self, on_connect_cb):
|
||||||
"""Register an async on_connect callback."""
|
"""Register an async on_connect callback."""
|
||||||
self._on_connect.append(on_connect_cb)
|
self._on_connect.append(on_connect_cb)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def register_on_disconnect(self, on_disconnect_cb):
|
||||||
|
"""Register an async on_disconnect callback."""
|
||||||
|
self._on_disconnect.append(on_disconnect_cb)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self):
|
def connected(self):
|
||||||
"""Return if we're currently connected."""
|
"""Return if we're currently connected."""
|
||||||
|
@ -102,6 +108,17 @@ class CloudIoT:
|
||||||
# Still adding it here to make sure we can always reconnect
|
# Still adding it here to make sure we can always reconnect
|
||||||
_LOGGER.exception("Unexpected error")
|
_LOGGER.exception("Unexpected error")
|
||||||
|
|
||||||
|
if self.state == STATE_CONNECTED and self._on_disconnect:
|
||||||
|
try:
|
||||||
|
yield from asyncio.wait([
|
||||||
|
cb() for cb in self._on_disconnect
|
||||||
|
])
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# Safety net. This should never hit.
|
||||||
|
# Still adding it here to make sure we don't break the flow
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Unexpected error in on_disconnect callbacks")
|
||||||
|
|
||||||
if self.close_requested:
|
if self.close_requested:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -192,7 +209,13 @@ class CloudIoT:
|
||||||
self.state = STATE_CONNECTED
|
self.state = STATE_CONNECTED
|
||||||
|
|
||||||
if self._on_connect:
|
if self._on_connect:
|
||||||
yield from asyncio.wait([cb() for cb in self._on_connect])
|
try:
|
||||||
|
yield from asyncio.wait([cb() for cb in self._on_connect])
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# Safety net. This should never hit.
|
||||||
|
# Still adding it here to make sure we don't break the flow
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Unexpected error in on_connect callbacks")
|
||||||
|
|
||||||
while not client.closed:
|
while not client.closed:
|
||||||
msg = yield from client.receive()
|
msg = yield from client.receive()
|
||||||
|
@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload):
|
||||||
await cloud.logout()
|
await cloud.logout()
|
||||||
_LOGGER.error("You have been logged out from Home Assistant cloud: %s",
|
_LOGGER.error("You have been logged out from Home Assistant cloud: %s",
|
||||||
payload['reason'])
|
payload['reason'])
|
||||||
elif action == 'refresh_auth':
|
|
||||||
# Refresh the auth token between now and payload['seconds']
|
|
||||||
hass.helpers.event.async_call_later(
|
|
||||||
random.randint(0, payload['seconds']),
|
|
||||||
lambda now: auth_api.check_token(cloud))
|
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("Received unknown cloud action: %s", action)
|
_LOGGER.warning("Received unknown cloud action: %s", action)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Tests for the tools to communicate with the cloud."""
|
"""Tests for the tools to communicate with the cloud."""
|
||||||
|
import asyncio
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito):
|
||||||
assert cloud.id_token != mock_cognito.id_token
|
assert cloud.id_token != mock_cognito.id_token
|
||||||
assert cloud.access_token != mock_cognito.access_token
|
assert cloud.access_token != mock_cognito.access_token
|
||||||
assert len(cloud.write_user_info.mock_calls) == 0
|
assert len(cloud.write_user_info.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_setup(hass):
|
||||||
|
"""Test async setup."""
|
||||||
|
cloud = MagicMock()
|
||||||
|
await auth_api.async_setup(hass, cloud)
|
||||||
|
assert len(cloud.iot.mock_calls) == 2
|
||||||
|
on_connect = cloud.iot.mock_calls[0][1][0]
|
||||||
|
on_disconnect = cloud.iot.mock_calls[1][1][0]
|
||||||
|
|
||||||
|
with patch('random.randint', return_value=0), patch(
|
||||||
|
'homeassistant.components.cloud.auth_api.renew_access_token'
|
||||||
|
) as mock_renew:
|
||||||
|
await on_connect()
|
||||||
|
# Let handle token sleep once
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Let handle token refresh token
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(mock_renew.mock_calls) == 1
|
||||||
|
assert mock_renew.mock_calls[0][1][0] is cloud
|
||||||
|
|
||||||
|
await on_disconnect()
|
||||||
|
|
||||||
|
# Make sure task is no longer being called
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert len(mock_renew.mock_calls) == 1
|
||||||
|
|
|
@ -10,9 +10,8 @@ from homeassistant.components.cloud import (
|
||||||
Cloud, iot, auth_api, MODE_DEV)
|
Cloud, iot, auth_api, MODE_DEV)
|
||||||
from homeassistant.components.cloud.const import (
|
from homeassistant.components.cloud.const import (
|
||||||
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
|
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
|
||||||
from homeassistant.util import dt as dt_util
|
|
||||||
from tests.components.alexa import test_smart_home as test_alexa
|
from tests.components.alexa import test_smart_home as test_alexa
|
||||||
from tests.common import mock_coro, async_fire_time_changed
|
from tests.common import mock_coro
|
||||||
|
|
||||||
from . import mock_cloud_prefs
|
from . import mock_cloud_prefs
|
||||||
|
|
||||||
|
@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud):
|
||||||
assert len(mock_cloud.logout.mock_calls) == 1
|
assert len(mock_cloud.logout.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_handling_core_messages_refresh_auth(hass, mock_cloud):
|
|
||||||
"""Test handling core messages."""
|
|
||||||
mock_cloud.hass = hass
|
|
||||||
with patch('random.randint', return_value=0) as mock_rand, patch(
|
|
||||||
'homeassistant.components.cloud.auth_api.check_token'
|
|
||||||
) as mock_check:
|
|
||||||
await iot.async_handle_cloud(hass, mock_cloud, {
|
|
||||||
'action': 'refresh_auth',
|
|
||||||
'seconds': 230,
|
|
||||||
})
|
|
||||||
async_fire_time_changed(hass, dt_util.utcnow())
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert len(mock_rand.mock_calls) == 1
|
|
||||||
assert mock_rand.mock_calls[0][1] == (0, 230)
|
|
||||||
|
|
||||||
assert len(mock_check.mock_calls) == 1
|
|
||||||
assert mock_check.mock_calls[0][1][0] is mock_cloud
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
|
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
|
||||||
"""Test server disconnecting instance."""
|
"""Test server disconnecting instance."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue