diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index d938dd20e67..98e649e1742 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -106,6 +106,7 @@ async def async_setup(hass, config): ) cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs) + await auth_api.async_setup(hass, cloud) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start) await http_api.async_setup(hass) return True @@ -263,7 +264,7 @@ class Cloud: self.access_token = info['access_token'] self.refresh_token = info['refresh_token'] - self.hass.add_job(self.iot.connect()) + self.hass.async_create_task(self.iot.connect()) def _decode_claims(self, token): # pylint: disable=no-self-use """Decode the claims in a token.""" diff --git a/homeassistant/components/cloud/auth_api.py b/homeassistant/components/cloud/auth_api.py index 954d28b803f..6019dac87b9 100644 --- a/homeassistant/components/cloud/auth_api.py +++ b/homeassistant/components/cloud/auth_api.py @@ -1,4 +1,10 @@ """Package to communicate with the authentication API.""" +import asyncio +import logging +import random + + +_LOGGER = logging.getLogger(__name__) class CloudError(Exception): @@ -39,6 +45,40 @@ AWS_EXCEPTIONS = { } +async def async_setup(hass, cloud): + """Configure the auth api.""" + refresh_task = None + + async def handle_token_refresh(): + """Handle Cloud access token refresh.""" + sleep_time = 5 + sleep_time = random.randint(2400, 3600) + while True: + try: + await asyncio.sleep(sleep_time) + await hass.async_add_executor_job(renew_access_token, cloud) + except CloudError as err: + _LOGGER.error("Can't refresh cloud token: %s", err) + except asyncio.CancelledError: + # Task is canceled, stop it. + break + + sleep_time = random.randint(3100, 3600) + + async def on_connect(): + """When the instance is connected.""" + nonlocal refresh_task + refresh_task = hass.async_create_task(handle_token_refresh()) + + async def on_disconnect(): + """When the instance is disconnected.""" + nonlocal refresh_task + refresh_task.cancel() + + cloud.iot.register_on_connect(on_connect) + cloud.iot.register_on_disconnect(on_disconnect) + + def _map_aws_exception(err): """Map AWS exception to our exceptions.""" ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError) @@ -47,7 +87,7 @@ def _map_aws_exception(err): def register(cloud, email, password): """Register a new account.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError cognito = _cognito(cloud) # Workaround for bug in Warrant. PR with fix: @@ -55,13 +95,16 @@ def register(cloud, email, password): cognito.add_base_attributes() try: cognito.register(email, password) + except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() def resend_email_confirm(cloud, email): """Resend email confirmation.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError cognito = _cognito(cloud, username=email) @@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email): ) except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() def forgot_password(cloud, email): """Initialize forgotten password flow.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError cognito = _cognito(cloud, username=email) try: cognito.initiate_forgot_password() + except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() def login(cloud, email, password): @@ -97,7 +145,7 @@ def login(cloud, email, password): def check_token(cloud): """Check that the token is valid and verify if needed.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError cognito = _cognito( cloud, @@ -109,13 +157,17 @@ def check_token(cloud): cloud.id_token = cognito.id_token cloud.access_token = cognito.access_token cloud.write_user_info() + except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() + def renew_access_token(cloud): """Renew access token.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError cognito = _cognito( cloud, @@ -127,13 +179,17 @@ def renew_access_token(cloud): cloud.id_token = cognito.id_token cloud.access_token = cognito.access_token cloud.write_user_info() + except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() + def _authenticate(cloud, email, password): """Log in and return an authenticated Cognito instance.""" - from botocore.exceptions import ClientError + from botocore.exceptions import ClientError, EndpointConnectionError from warrant.exceptions import ForceChangePasswordException assert not cloud.is_logged_in, 'Cannot login if already logged in.' @@ -145,11 +201,14 @@ def _authenticate(cloud, email, password): return cognito except ForceChangePasswordException: - raise PasswordChangeRequired + raise PasswordChangeRequired() except ClientError as err: raise _map_aws_exception(err) + except EndpointConnectionError: + raise UnknownError() + def _cognito(cloud, **kwargs): """Get the client credentials.""" diff --git a/homeassistant/components/cloud/iot.py b/homeassistant/components/cloud/iot.py index d725cb309bc..055c4dbaa64 100644 --- a/homeassistant/components/cloud/iot.py +++ b/homeassistant/components/cloud/iot.py @@ -62,12 +62,18 @@ class CloudIoT: # Local code waiting for a response self._response_handler = {} self._on_connect = [] + self._on_disconnect = [] @callback def register_on_connect(self, on_connect_cb): """Register an async on_connect callback.""" self._on_connect.append(on_connect_cb) + @callback + def register_on_disconnect(self, on_disconnect_cb): + """Register an async on_disconnect callback.""" + self._on_disconnect.append(on_disconnect_cb) + @property def connected(self): """Return if we're currently connected.""" @@ -102,6 +108,17 @@ class CloudIoT: # Still adding it here to make sure we can always reconnect _LOGGER.exception("Unexpected error") + if self.state == STATE_CONNECTED and self._on_disconnect: + try: + yield from asyncio.wait([ + cb() for cb in self._on_disconnect + ]) + except Exception: # pylint: disable=broad-except + # Safety net. This should never hit. + # Still adding it here to make sure we don't break the flow + _LOGGER.exception( + "Unexpected error in on_disconnect callbacks") + if self.close_requested: break @@ -192,7 +209,13 @@ class CloudIoT: self.state = STATE_CONNECTED if self._on_connect: - yield from asyncio.wait([cb() for cb in self._on_connect]) + try: + yield from asyncio.wait([cb() for cb in self._on_connect]) + except Exception: # pylint: disable=broad-except + # Safety net. This should never hit. + # Still adding it here to make sure we don't break the flow + _LOGGER.exception( + "Unexpected error in on_connect callbacks") while not client.closed: msg = yield from client.receive() @@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload): await cloud.logout() _LOGGER.error("You have been logged out from Home Assistant cloud: %s", payload['reason']) - elif action == 'refresh_auth': - # Refresh the auth token between now and payload['seconds'] - hass.helpers.event.async_call_later( - random.randint(0, payload['seconds']), - lambda now: auth_api.check_token(cloud)) else: _LOGGER.warning("Received unknown cloud action: %s", action) diff --git a/tests/components/cloud/test_auth_api.py b/tests/components/cloud/test_auth_api.py index a50a4d796aa..bdf9939cb2b 100644 --- a/tests/components/cloud/test_auth_api.py +++ b/tests/components/cloud/test_auth_api.py @@ -1,4 +1,5 @@ """Tests for the tools to communicate with the cloud.""" +import asyncio from unittest.mock import MagicMock, patch from botocore.exceptions import ClientError @@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito): assert cloud.id_token != mock_cognito.id_token assert cloud.access_token != mock_cognito.access_token assert len(cloud.write_user_info.mock_calls) == 0 + + +async def test_async_setup(hass): + """Test async setup.""" + cloud = MagicMock() + await auth_api.async_setup(hass, cloud) + assert len(cloud.iot.mock_calls) == 2 + on_connect = cloud.iot.mock_calls[0][1][0] + on_disconnect = cloud.iot.mock_calls[1][1][0] + + with patch('random.randint', return_value=0), patch( + 'homeassistant.components.cloud.auth_api.renew_access_token' + ) as mock_renew: + await on_connect() + # Let handle token sleep once + await asyncio.sleep(0) + # Let handle token refresh token + await asyncio.sleep(0) + + assert len(mock_renew.mock_calls) == 1 + assert mock_renew.mock_calls[0][1][0] is cloud + + await on_disconnect() + + # Make sure task is no longer being called + await asyncio.sleep(0) + await asyncio.sleep(0) + assert len(mock_renew.mock_calls) == 1 diff --git a/tests/components/cloud/test_iot.py b/tests/components/cloud/test_iot.py index 1a528f8cedf..10a94f46833 100644 --- a/tests/components/cloud/test_iot.py +++ b/tests/components/cloud/test_iot.py @@ -10,9 +10,8 @@ from homeassistant.components.cloud import ( Cloud, iot, auth_api, MODE_DEV) from homeassistant.components.cloud.const import ( PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE) -from homeassistant.util import dt as dt_util from tests.components.alexa import test_smart_home as test_alexa -from tests.common import mock_coro, async_fire_time_changed +from tests.common import mock_coro from . import mock_cloud_prefs @@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud): assert len(mock_cloud.logout.mock_calls) == 1 -async def test_handling_core_messages_refresh_auth(hass, mock_cloud): - """Test handling core messages.""" - mock_cloud.hass = hass - with patch('random.randint', return_value=0) as mock_rand, patch( - 'homeassistant.components.cloud.auth_api.check_token' - ) as mock_check: - await iot.async_handle_cloud(hass, mock_cloud, { - 'action': 'refresh_auth', - 'seconds': 230, - }) - async_fire_time_changed(hass, dt_util.utcnow()) - await hass.async_block_till_done() - - assert len(mock_rand.mock_calls) == 1 - assert mock_rand.mock_calls[0][1] == (0, 230) - - assert len(mock_check.mock_calls) == 1 - assert mock_check.mock_calls[0][1][0] is mock_cloud - - @asyncio.coroutine def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud): """Test server disconnecting instance."""