diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index c5d709d60c3..2d01399bc07 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -1,5 +1,6 @@ """Component to integrate the Home Assistant cloud.""" import asyncio +from datetime import datetime import json import logging import os @@ -8,6 +9,7 @@ import voluptuous as vol from homeassistant.const import ( EVENT_HOMEASSISTANT_START, CONF_REGION, CONF_MODE) +from homeassistant.util import dt as dt_util from . import http_api, iot from .const import CONFIG_DIR, DOMAIN, SERVERS @@ -66,7 +68,6 @@ class Cloud: """Create an instance of Cloud.""" self.hass = hass self.mode = mode - self.email = None self.id_token = None self.access_token = None self.refresh_token = None @@ -89,7 +90,29 @@ class Cloud: @property def is_logged_in(self): """Get if cloud is logged in.""" - return self.email is not None + return self.id_token is not None + + @property + def subscription_expired(self): + """Return a boolen if the subscription has expired.""" + # For now, don't enforce subscriptions to exist + if 'custom:sub-exp' not in self.claims: + return False + + return dt_util.utcnow() > self.expiration_date + + @property + def expiration_date(self): + """Return the subscription expiration as a UTC datetime object.""" + return datetime.combine( + dt_util.parse_date(self.claims['custom:sub-exp']), + datetime.min.time()).replace(tzinfo=dt_util.UTC) + + @property + def claims(self): + """Get the claims from the id token.""" + from jose import jwt + return jwt.get_unverified_claims(self.id_token) @property def user_info_path(self): @@ -110,18 +133,20 @@ class Cloud: if os.path.isfile(user_info): with open(user_info, 'rt') as file: info = json.loads(file.read()) - self.email = info['email'] self.id_token = info['id_token'] self.access_token = info['access_token'] self.refresh_token = info['refresh_token'] yield from self.hass.async_add_job(load_config) - if self.email is not None: + if self.id_token is not None: yield from self.iot.connect() def path(self, *parts): - """Get config path inside cloud dir.""" + """Get config path inside cloud dir. + + Async friendly. + """ return self.hass.config.path(CONFIG_DIR, *parts) @asyncio.coroutine @@ -129,7 +154,6 @@ class Cloud: """Close connection and remove all credentials.""" yield from self.iot.disconnect() - self.email = None self.id_token = None self.access_token = None self.refresh_token = None @@ -141,7 +165,6 @@ class Cloud: """Write user info to a file.""" with open(self.user_info_path, 'wt') as file: file.write(json.dumps({ - 'email': self.email, 'id_token': self.id_token, 'access_token': self.access_token, 'refresh_token': self.refresh_token, diff --git a/homeassistant/components/cloud/auth_api.py b/homeassistant/components/cloud/auth_api.py index 50a88d4be4d..cb9fe15ab4a 100644 --- a/homeassistant/components/cloud/auth_api.py +++ b/homeassistant/components/cloud/auth_api.py @@ -113,7 +113,6 @@ def login(cloud, email, password): cloud.id_token = cognito.id_token cloud.access_token = cognito.access_token cloud.refresh_token = cognito.refresh_token - cloud.email = email cloud.write_user_info() diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index 334e522f81b..440e4179eea 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -12,3 +12,8 @@ SERVERS = { # 'relayer': '' # } } + +MESSAGE_EXPIRATION = """ +It looks like your Home Assistant Cloud subscription has expired. Please check +your [account page](/config/cloud/account) to continue using the service. +""" diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index aa91f5a45e7..d16df130c48 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -79,8 +79,10 @@ class CloudLoginView(HomeAssistantView): with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop): yield from hass.async_add_job(auth_api.login, cloud, data['email'], data['password']) - hass.async_add_job(cloud.iot.connect) + hass.async_add_job(cloud.iot.connect) + # Allow cloud to start connecting. + yield from asyncio.sleep(0, loop=hass.loop) return self.json(_account_data(cloud)) @@ -222,6 +224,10 @@ class CloudConfirmForgotPasswordView(HomeAssistantView): def _account_data(cloud): """Generate the auth data JSON response.""" + claims = cloud.claims + return { - 'email': cloud.email + 'email': claims['email'], + 'sub_exp': claims.get('custom:sub-exp'), + 'cloud': cloud.iot.state, } diff --git a/homeassistant/components/cloud/iot.py b/homeassistant/components/cloud/iot.py index 1bb6668e0cc..c0b6bb96da1 100644 --- a/homeassistant/components/cloud/iot.py +++ b/homeassistant/components/cloud/iot.py @@ -9,11 +9,16 @@ from homeassistant.components.alexa import smart_home from homeassistant.util.decorator import Registry from homeassistant.helpers.aiohttp_client import async_get_clientsession from . import auth_api +from .const import MESSAGE_EXPIRATION HANDLERS = Registry() _LOGGER = logging.getLogger(__name__) +STATE_CONNECTING = 'connecting' +STATE_CONNECTED = 'connected' +STATE_DISCONNECTED = 'disconnected' + class UnknownHandler(Exception): """Exception raised when trying to handle unknown handler.""" @@ -25,27 +30,41 @@ class CloudIoT: def __init__(self, cloud): """Initialize the CloudIoT class.""" self.cloud = cloud + # The WebSocket client self.client = None + # Scheduled sleep task till next connection retry + self.retry_task = None + # Boolean to indicate if we wanted the connection to close self.close_requested = False + # The current number of attempts to connect, impacts wait time self.tries = 0 - - @property - def is_connected(self): - """Return if connected to the cloud.""" - return self.client is not None + # Current state of the connection + self.state = STATE_DISCONNECTED @asyncio.coroutine def connect(self): """Connect to the IoT broker.""" - if self.client is not None: - raise RuntimeError('Cannot connect while already connected') - - self.close_requested = False - hass = self.cloud.hass - remove_hass_stop_listener = None + if self.cloud.subscription_expired: + # Try refreshing the token to see if it is still expired. + yield from hass.async_add_job(auth_api.check_token, self.cloud) + if self.cloud.subscription_expired: + hass.components.persistent_notification.async_create( + MESSAGE_EXPIRATION, 'Subscription expired', + 'cloud_subscription_expired') + self.state = STATE_DISCONNECTED + return + + if self.state == STATE_CONNECTED: + raise RuntimeError('Already connected') + + self.state = STATE_CONNECTING + self.close_requested = False + remove_hass_stop_listener = None session = async_get_clientsession(self.cloud.hass) + client = None + disconnect_warn = None @asyncio.coroutine def _handle_hass_stop(event): @@ -54,8 +73,6 @@ class CloudIoT: remove_hass_stop_listener = None yield from self.disconnect() - client = None - disconnect_warn = None try: yield from hass.async_add_job(auth_api.check_token, self.cloud) @@ -70,13 +87,14 @@ class CloudIoT: EVENT_HOMEASSISTANT_STOP, _handle_hass_stop) _LOGGER.info('Connected') + self.state = STATE_CONNECTED while not client.closed: msg = yield from client.receive() if msg.type in (WSMsgType.ERROR, WSMsgType.CLOSED, WSMsgType.CLOSING): - disconnect_warn = 'Closed by server' + disconnect_warn = 'Connection cancelled.' break elif msg.type != WSMsgType.TEXT: @@ -144,20 +162,33 @@ class CloudIoT: self.client = None yield from client.close() - if not self.close_requested: + if self.close_requested: + self.state = STATE_DISCONNECTED + + else: + self.state = STATE_CONNECTING self.tries += 1 - # Sleep 0, 5, 10, 15 … up to 30 seconds between retries - yield from asyncio.sleep( - min(30, (self.tries - 1) * 5), loop=hass.loop) - - hass.async_add_job(self.connect()) + try: + # Sleep 0, 5, 10, 15 … up to 30 seconds between retries + self.retry_task = hass.async_add_job(asyncio.sleep( + min(30, (self.tries - 1) * 5), loop=hass.loop)) + yield from self.retry_task + self.retry_task = None + hass.async_add_job(self.connect()) + except asyncio.CancelledError: + # Happens if disconnect called + pass @asyncio.coroutine def disconnect(self): """Disconnect the client.""" self.close_requested = True - yield from self.client.close() + + if self.client is not None: + yield from self.client.close() + elif self.retry_task is not None: + self.retry_task.cancel() @asyncio.coroutine diff --git a/tests/components/cloud/test_auth_api.py b/tests/components/cloud/test_auth_api.py index d9f005fdcfa..20f9265a1c1 100644 --- a/tests/components/cloud/test_auth_api.py +++ b/tests/components/cloud/test_auth_api.py @@ -69,7 +69,6 @@ def test_login(mock_cognito): auth_api.login(cloud, 'user', 'pass') assert len(mock_cognito.authenticate.mock_calls) == 1 - assert cloud.email == 'user' assert cloud.id_token == 'test_id_token' assert cloud.access_token == 'test_access_token' assert cloud.refresh_token == 'test_refresh_token' diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 1090acb01e9..296baa3f143 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -3,9 +3,10 @@ import asyncio from unittest.mock import patch, MagicMock import pytest +from jose import jwt from homeassistant.bootstrap import async_setup_component -from homeassistant.components.cloud import DOMAIN, auth_api +from homeassistant.components.cloud import DOMAIN, auth_api, iot from tests.common import mock_coro @@ -23,7 +24,8 @@ def cloud_client(hass, test_client): 'relayer': 'relayer', } })) - return hass.loop.run_until_complete(test_client(hass.http.app)) + with patch('homeassistant.components.cloud.Cloud.write_user_info'): + yield hass.loop.run_until_complete(test_client(hass.http.app)) @pytest.fixture @@ -43,21 +45,35 @@ def test_account_view_no_account(cloud_client): @asyncio.coroutine def test_account_view(hass, cloud_client): """Test fetching account if no account available.""" - hass.data[DOMAIN].email = 'hello@home-assistant.io' + hass.data[DOMAIN].id_token = jwt.encode({ + 'email': 'hello@home-assistant.io', + 'custom:sub-exp': '2018-01-03' + }, 'test') + hass.data[DOMAIN].iot.state = iot.STATE_CONNECTED req = yield from cloud_client.get('/api/cloud/account') assert req.status == 200 result = yield from req.json() - assert result == {'email': 'hello@home-assistant.io'} + assert result == { + 'email': 'hello@home-assistant.io', + 'sub_exp': '2018-01-03', + 'cloud': iot.STATE_CONNECTED, + } @asyncio.coroutine -def test_login_view(hass, cloud_client): +def test_login_view(hass, cloud_client, mock_cognito): """Test logging in.""" - hass.data[DOMAIN].email = 'hello@home-assistant.io' + mock_cognito.id_token = jwt.encode({ + 'email': 'hello@home-assistant.io', + 'custom:sub-exp': '2018-01-03' + }, 'test') + mock_cognito.access_token = 'access_token' + mock_cognito.refresh_token = 'refresh_token' - with patch('homeassistant.components.cloud.iot.CloudIoT.connect'), \ - patch('homeassistant.components.cloud.' - 'auth_api.login') as mock_login: + with patch('homeassistant.components.cloud.iot.CloudIoT.' + 'connect') as mock_connect, \ + patch('homeassistant.components.cloud.auth_api._authenticate', + return_value=mock_cognito) as mock_auth: req = yield from cloud_client.post('/api/cloud/login', json={ 'email': 'my_username', 'password': 'my_password' @@ -65,9 +81,13 @@ def test_login_view(hass, cloud_client): assert req.status == 200 result = yield from req.json() - assert result == {'email': 'hello@home-assistant.io'} - assert len(mock_login.mock_calls) == 1 - cloud, result_user, result_pass = mock_login.mock_calls[0][1] + assert result['email'] == 'hello@home-assistant.io' + assert result['sub_exp'] == '2018-01-03' + + assert len(mock_connect.mock_calls) == 1 + + assert len(mock_auth.mock_calls) == 1 + cloud, result_user, result_pass = mock_auth.mock_calls[0][1] assert result_user == 'my_username' assert result_pass == 'my_password' diff --git a/tests/components/cloud/test_init.py b/tests/components/cloud/test_init.py index 1eb1051520f..c05fdabf465 100644 --- a/tests/components/cloud/test_init.py +++ b/tests/components/cloud/test_init.py @@ -3,9 +3,11 @@ import asyncio import json from unittest.mock import patch, MagicMock, mock_open +from jose import jwt import pytest from homeassistant.components import cloud +from homeassistant.util.dt import utcnow from tests.common import mock_coro @@ -72,7 +74,6 @@ def test_initialize_loads_info(mock_os, hass): """Test initialize will load info from config file.""" mock_os.path.isfile.return_value = True mopen = mock_open(read_data=json.dumps({ - 'email': 'test-email', 'id_token': 'test-id-token', 'access_token': 'test-access-token', 'refresh_token': 'test-refresh-token', @@ -85,7 +86,6 @@ def test_initialize_loads_info(mock_os, hass): with patch('homeassistant.components.cloud.open', mopen, create=True): yield from cl.initialize() - assert cl.email == 'test-email' assert cl.id_token == 'test-id-token' assert cl.access_token == 'test-access-token' assert cl.refresh_token == 'test-refresh-token' @@ -102,7 +102,6 @@ def test_logout_clears_info(mock_os, hass): yield from cl.logout() assert len(cl.iot.disconnect.mock_calls) == 1 - assert cl.email is None assert cl.id_token is None assert cl.access_token is None assert cl.refresh_token is None @@ -115,7 +114,6 @@ def test_write_user_info(): mopen = mock_open() cl = cloud.Cloud(MagicMock(), cloud.MODE_DEV) - cl.email = 'test-email' cl.id_token = 'test-id-token' cl.access_token = 'test-access-token' cl.refresh_token = 'test-refresh-token' @@ -129,7 +127,41 @@ def test_write_user_info(): data = json.loads(handle.write.mock_calls[0][1][0]) assert data == { 'access_token': 'test-access-token', - 'email': 'test-email', 'id_token': 'test-id-token', 'refresh_token': 'test-refresh-token', } + + +@asyncio.coroutine +def test_subscription_not_expired_without_sub_in_claim(): + """Test that we do not enforce subscriptions yet.""" + cl = cloud.Cloud(None, cloud.MODE_DEV) + cl.id_token = jwt.encode({}, 'test') + + assert not cl.subscription_expired + + +@asyncio.coroutine +def test_subscription_expired(): + """Test subscription being expired.""" + cl = cloud.Cloud(None, cloud.MODE_DEV) + cl.id_token = jwt.encode({ + 'custom:sub-exp': '2017-11-13' + }, 'test') + + with patch('homeassistant.util.dt.utcnow', + return_value=utcnow().replace(year=2018)): + assert cl.subscription_expired + + +@asyncio.coroutine +def test_subscription_not_expired(): + """Test subscription not being expired.""" + cl = cloud.Cloud(None, cloud.MODE_DEV) + cl.id_token = jwt.encode({ + 'custom:sub-exp': '2017-11-13' + }, 'test') + + with patch('homeassistant.util.dt.utcnow', + return_value=utcnow().replace(year=2017, month=11, day=9)): + assert not cl.subscription_expired diff --git a/tests/components/cloud/test_iot.py b/tests/components/cloud/test_iot.py index f1254cdb3c7..be5a93c9e47 100644 --- a/tests/components/cloud/test_iot.py +++ b/tests/components/cloud/test_iot.py @@ -30,11 +30,16 @@ def mock_handle_message(): yield mock +@pytest.fixture +def mock_cloud(): + """Mock cloud class.""" + return MagicMock(subscription_expired=False) + + @asyncio.coroutine -def test_cloud_calling_handler(mock_client, mock_handle_message): +def test_cloud_calling_handler(mock_client, mock_handle_message, mock_cloud): """Test we call handle message with correct info.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.text, json=MagicMock(return_value={ @@ -53,8 +58,8 @@ def test_cloud_calling_handler(mock_client, mock_handle_message): p_hass, p_cloud, handler_name, payload = \ mock_handle_message.mock_calls[0][1] - assert p_hass is cloud.hass - assert p_cloud is cloud + assert p_hass is mock_cloud.hass + assert p_cloud is mock_cloud assert handler_name == 'test-handler' assert payload == 'test-payload' @@ -67,10 +72,9 @@ def test_cloud_calling_handler(mock_client, mock_handle_message): @asyncio.coroutine -def test_connection_msg_for_unknown_handler(mock_client): +def test_connection_msg_for_unknown_handler(mock_client, mock_cloud): """Test a msg for an unknown handler.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.text, json=MagicMock(return_value={ @@ -92,10 +96,10 @@ def test_connection_msg_for_unknown_handler(mock_client): @asyncio.coroutine -def test_connection_msg_for_handler_raising(mock_client, mock_handle_message): +def test_connection_msg_for_handler_raising(mock_client, mock_handle_message, + mock_cloud): """Test we sent error when handler raises exception.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.text, json=MagicMock(return_value={ @@ -136,37 +140,34 @@ def test_handler_forwarding(): @asyncio.coroutine -def test_handling_core_messages(hass): +def test_handling_core_messages(hass, mock_cloud): """Test handling core messages.""" - cloud = MagicMock() - cloud.logout.return_value = mock_coro() - yield from iot.async_handle_cloud(hass, cloud, { + mock_cloud.logout.return_value = mock_coro() + yield from iot.async_handle_cloud(hass, mock_cloud, { 'action': 'logout', 'reason': 'Logged in at two places.' }) - assert len(cloud.logout.mock_calls) == 1 + assert len(mock_cloud.logout.mock_calls) == 1 @asyncio.coroutine -def test_cloud_getting_disconnected_by_server(mock_client, caplog): +def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud): """Test server disconnecting instance.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.CLOSING, )) yield from conn.connect() - assert 'Connection closed: Closed by server' in caplog.text - assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0]) + assert 'Connection closed: Connection cancelled.' in caplog.text + assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0]) @asyncio.coroutine -def test_cloud_receiving_bytes(mock_client, caplog): +def test_cloud_receiving_bytes(mock_client, caplog, mock_cloud): """Test server disconnecting instance.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.BINARY, )) @@ -174,14 +175,13 @@ def test_cloud_receiving_bytes(mock_client, caplog): yield from conn.connect() assert 'Connection closed: Received non-Text message' in caplog.text - assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0]) + assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0]) @asyncio.coroutine -def test_cloud_sending_invalid_json(mock_client, caplog): +def test_cloud_sending_invalid_json(mock_client, caplog, mock_cloud): """Test cloud sending invalid JSON.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.return_value = mock_coro(MagicMock( type=WSMsgType.TEXT, json=MagicMock(side_effect=ValueError) @@ -190,27 +190,25 @@ def test_cloud_sending_invalid_json(mock_client, caplog): yield from conn.connect() assert 'Connection closed: Received invalid JSON.' in caplog.text - assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0]) + assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0]) @asyncio.coroutine -def test_cloud_check_token_raising(mock_client, caplog): +def test_cloud_check_token_raising(mock_client, caplog, mock_cloud): """Test cloud sending invalid JSON.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.side_effect = auth_api.CloudError yield from conn.connect() assert 'Unable to connect: Unable to refresh token.' in caplog.text - assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0]) + assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0]) @asyncio.coroutine -def test_cloud_connect_invalid_auth(mock_client, caplog): +def test_cloud_connect_invalid_auth(mock_client, caplog, mock_cloud): """Test invalid auth detected by server.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.side_effect = \ client_exceptions.WSServerHandshakeError(None, None, code=401) @@ -220,10 +218,9 @@ def test_cloud_connect_invalid_auth(mock_client, caplog): @asyncio.coroutine -def test_cloud_unable_to_connect(mock_client, caplog): +def test_cloud_unable_to_connect(mock_client, caplog, mock_cloud): """Test unable to connect error.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.side_effect = client_exceptions.ClientError(None, None) yield from conn.connect() @@ -232,12 +229,28 @@ def test_cloud_unable_to_connect(mock_client, caplog): @asyncio.coroutine -def test_cloud_random_exception(mock_client, caplog): +def test_cloud_random_exception(mock_client, caplog, mock_cloud): """Test random exception.""" - cloud = MagicMock() - conn = iot.CloudIoT(cloud) + conn = iot.CloudIoT(mock_cloud) mock_client.receive.side_effect = Exception yield from conn.connect() assert 'Unexpected error' in caplog.text + + +@asyncio.coroutine +def test_refresh_token_before_expiration_fails(hass, mock_cloud): + """Test that we don't connect if token is expired.""" + mock_cloud.subscription_expired = True + mock_cloud.hass = hass + conn = iot.CloudIoT(mock_cloud) + + with patch('homeassistant.components.cloud.auth_api.check_token', + return_value=mock_coro()) as mock_check_token, \ + patch.object(hass.components.persistent_notification, + 'async_create') as mock_create: + yield from conn.connect() + + assert len(mock_check_token.mock_calls) == 1 + assert len(mock_create.mock_calls) == 1