Cloud updates (#10567)

* Update cloud

* Fix tests

* Lint
This commit is contained in:
Paulus Schoutsen 2017-11-14 23:16:19 -08:00 committed by Pascal Vizeli
parent 0cd3271dfa
commit ea7ffff0ca
9 changed files with 219 additions and 91 deletions

View file

@ -1,5 +1,6 @@
"""Component to integrate the Home Assistant cloud."""
import asyncio
from datetime import datetime
import json
import logging
import os
@ -8,6 +9,7 @@ import voluptuous as vol
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, CONF_REGION, CONF_MODE)
from homeassistant.util import dt as dt_util
from . import http_api, iot
from .const import CONFIG_DIR, DOMAIN, SERVERS
@ -66,7 +68,6 @@ class Cloud:
"""Create an instance of Cloud."""
self.hass = hass
self.mode = mode
self.email = None
self.id_token = None
self.access_token = None
self.refresh_token = None
@ -89,7 +90,29 @@ class Cloud:
@property
def is_logged_in(self):
"""Get if cloud is logged in."""
return self.email is not None
return self.id_token is not None
@property
def subscription_expired(self):
"""Return a boolen if the subscription has expired."""
# For now, don't enforce subscriptions to exist
if 'custom:sub-exp' not in self.claims:
return False
return dt_util.utcnow() > self.expiration_date
@property
def expiration_date(self):
"""Return the subscription expiration as a UTC datetime object."""
return datetime.combine(
dt_util.parse_date(self.claims['custom:sub-exp']),
datetime.min.time()).replace(tzinfo=dt_util.UTC)
@property
def claims(self):
"""Get the claims from the id token."""
from jose import jwt
return jwt.get_unverified_claims(self.id_token)
@property
def user_info_path(self):
@ -110,18 +133,20 @@ class Cloud:
if os.path.isfile(user_info):
with open(user_info, 'rt') as file:
info = json.loads(file.read())
self.email = info['email']
self.id_token = info['id_token']
self.access_token = info['access_token']
self.refresh_token = info['refresh_token']
yield from self.hass.async_add_job(load_config)
if self.email is not None:
if self.id_token is not None:
yield from self.iot.connect()
def path(self, *parts):
"""Get config path inside cloud dir."""
"""Get config path inside cloud dir.
Async friendly.
"""
return self.hass.config.path(CONFIG_DIR, *parts)
@asyncio.coroutine
@ -129,7 +154,6 @@ class Cloud:
"""Close connection and remove all credentials."""
yield from self.iot.disconnect()
self.email = None
self.id_token = None
self.access_token = None
self.refresh_token = None
@ -141,7 +165,6 @@ class Cloud:
"""Write user info to a file."""
with open(self.user_info_path, 'wt') as file:
file.write(json.dumps({
'email': self.email,
'id_token': self.id_token,
'access_token': self.access_token,
'refresh_token': self.refresh_token,

View file

@ -113,7 +113,6 @@ def login(cloud, email, password):
cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token
cloud.refresh_token = cognito.refresh_token
cloud.email = email
cloud.write_user_info()

View file

@ -12,3 +12,8 @@ SERVERS = {
# 'relayer': ''
# }
}
MESSAGE_EXPIRATION = """
It looks like your Home Assistant Cloud subscription has expired. Please check
your [account page](/config/cloud/account) to continue using the service.
"""

View file

@ -79,8 +79,10 @@ class CloudLoginView(HomeAssistantView):
with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
yield from hass.async_add_job(auth_api.login, cloud, data['email'],
data['password'])
hass.async_add_job(cloud.iot.connect)
hass.async_add_job(cloud.iot.connect)
# Allow cloud to start connecting.
yield from asyncio.sleep(0, loop=hass.loop)
return self.json(_account_data(cloud))
@ -222,6 +224,10 @@ class CloudConfirmForgotPasswordView(HomeAssistantView):
def _account_data(cloud):
"""Generate the auth data JSON response."""
claims = cloud.claims
return {
'email': cloud.email
'email': claims['email'],
'sub_exp': claims.get('custom:sub-exp'),
'cloud': cloud.iot.state,
}

View file

@ -9,11 +9,16 @@ from homeassistant.components.alexa import smart_home
from homeassistant.util.decorator import Registry
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from . import auth_api
from .const import MESSAGE_EXPIRATION
HANDLERS = Registry()
_LOGGER = logging.getLogger(__name__)
STATE_CONNECTING = 'connecting'
STATE_CONNECTED = 'connected'
STATE_DISCONNECTED = 'disconnected'
class UnknownHandler(Exception):
"""Exception raised when trying to handle unknown handler."""
@ -25,27 +30,41 @@ class CloudIoT:
def __init__(self, cloud):
"""Initialize the CloudIoT class."""
self.cloud = cloud
# The WebSocket client
self.client = None
# Scheduled sleep task till next connection retry
self.retry_task = None
# Boolean to indicate if we wanted the connection to close
self.close_requested = False
# The current number of attempts to connect, impacts wait time
self.tries = 0
@property
def is_connected(self):
"""Return if connected to the cloud."""
return self.client is not None
# Current state of the connection
self.state = STATE_DISCONNECTED
@asyncio.coroutine
def connect(self):
"""Connect to the IoT broker."""
if self.client is not None:
raise RuntimeError('Cannot connect while already connected')
self.close_requested = False
hass = self.cloud.hass
remove_hass_stop_listener = None
if self.cloud.subscription_expired:
# Try refreshing the token to see if it is still expired.
yield from hass.async_add_job(auth_api.check_token, self.cloud)
if self.cloud.subscription_expired:
hass.components.persistent_notification.async_create(
MESSAGE_EXPIRATION, 'Subscription expired',
'cloud_subscription_expired')
self.state = STATE_DISCONNECTED
return
if self.state == STATE_CONNECTED:
raise RuntimeError('Already connected')
self.state = STATE_CONNECTING
self.close_requested = False
remove_hass_stop_listener = None
session = async_get_clientsession(self.cloud.hass)
client = None
disconnect_warn = None
@asyncio.coroutine
def _handle_hass_stop(event):
@ -54,8 +73,6 @@ class CloudIoT:
remove_hass_stop_listener = None
yield from self.disconnect()
client = None
disconnect_warn = None
try:
yield from hass.async_add_job(auth_api.check_token, self.cloud)
@ -70,13 +87,14 @@ class CloudIoT:
EVENT_HOMEASSISTANT_STOP, _handle_hass_stop)
_LOGGER.info('Connected')
self.state = STATE_CONNECTED
while not client.closed:
msg = yield from client.receive()
if msg.type in (WSMsgType.ERROR, WSMsgType.CLOSED,
WSMsgType.CLOSING):
disconnect_warn = 'Closed by server'
disconnect_warn = 'Connection cancelled.'
break
elif msg.type != WSMsgType.TEXT:
@ -144,20 +162,33 @@ class CloudIoT:
self.client = None
yield from client.close()
if not self.close_requested:
if self.close_requested:
self.state = STATE_DISCONNECTED
else:
self.state = STATE_CONNECTING
self.tries += 1
try:
# Sleep 0, 5, 10, 15 … up to 30 seconds between retries
yield from asyncio.sleep(
min(30, (self.tries - 1) * 5), loop=hass.loop)
self.retry_task = hass.async_add_job(asyncio.sleep(
min(30, (self.tries - 1) * 5), loop=hass.loop))
yield from self.retry_task
self.retry_task = None
hass.async_add_job(self.connect())
except asyncio.CancelledError:
# Happens if disconnect called
pass
@asyncio.coroutine
def disconnect(self):
"""Disconnect the client."""
self.close_requested = True
if self.client is not None:
yield from self.client.close()
elif self.retry_task is not None:
self.retry_task.cancel()
@asyncio.coroutine

View file

@ -69,7 +69,6 @@ def test_login(mock_cognito):
auth_api.login(cloud, 'user', 'pass')
assert len(mock_cognito.authenticate.mock_calls) == 1
assert cloud.email == 'user'
assert cloud.id_token == 'test_id_token'
assert cloud.access_token == 'test_access_token'
assert cloud.refresh_token == 'test_refresh_token'

View file

@ -3,9 +3,10 @@ import asyncio
from unittest.mock import patch, MagicMock
import pytest
from jose import jwt
from homeassistant.bootstrap import async_setup_component
from homeassistant.components.cloud import DOMAIN, auth_api
from homeassistant.components.cloud import DOMAIN, auth_api, iot
from tests.common import mock_coro
@ -23,7 +24,8 @@ def cloud_client(hass, test_client):
'relayer': 'relayer',
}
}))
return hass.loop.run_until_complete(test_client(hass.http.app))
with patch('homeassistant.components.cloud.Cloud.write_user_info'):
yield hass.loop.run_until_complete(test_client(hass.http.app))
@pytest.fixture
@ -43,21 +45,35 @@ def test_account_view_no_account(cloud_client):
@asyncio.coroutine
def test_account_view(hass, cloud_client):
"""Test fetching account if no account available."""
hass.data[DOMAIN].email = 'hello@home-assistant.io'
hass.data[DOMAIN].id_token = jwt.encode({
'email': 'hello@home-assistant.io',
'custom:sub-exp': '2018-01-03'
}, 'test')
hass.data[DOMAIN].iot.state = iot.STATE_CONNECTED
req = yield from cloud_client.get('/api/cloud/account')
assert req.status == 200
result = yield from req.json()
assert result == {'email': 'hello@home-assistant.io'}
assert result == {
'email': 'hello@home-assistant.io',
'sub_exp': '2018-01-03',
'cloud': iot.STATE_CONNECTED,
}
@asyncio.coroutine
def test_login_view(hass, cloud_client):
def test_login_view(hass, cloud_client, mock_cognito):
"""Test logging in."""
hass.data[DOMAIN].email = 'hello@home-assistant.io'
mock_cognito.id_token = jwt.encode({
'email': 'hello@home-assistant.io',
'custom:sub-exp': '2018-01-03'
}, 'test')
mock_cognito.access_token = 'access_token'
mock_cognito.refresh_token = 'refresh_token'
with patch('homeassistant.components.cloud.iot.CloudIoT.connect'), \
patch('homeassistant.components.cloud.'
'auth_api.login') as mock_login:
with patch('homeassistant.components.cloud.iot.CloudIoT.'
'connect') as mock_connect, \
patch('homeassistant.components.cloud.auth_api._authenticate',
return_value=mock_cognito) as mock_auth:
req = yield from cloud_client.post('/api/cloud/login', json={
'email': 'my_username',
'password': 'my_password'
@ -65,9 +81,13 @@ def test_login_view(hass, cloud_client):
assert req.status == 200
result = yield from req.json()
assert result == {'email': 'hello@home-assistant.io'}
assert len(mock_login.mock_calls) == 1
cloud, result_user, result_pass = mock_login.mock_calls[0][1]
assert result['email'] == 'hello@home-assistant.io'
assert result['sub_exp'] == '2018-01-03'
assert len(mock_connect.mock_calls) == 1
assert len(mock_auth.mock_calls) == 1
cloud, result_user, result_pass = mock_auth.mock_calls[0][1]
assert result_user == 'my_username'
assert result_pass == 'my_password'

View file

@ -3,9 +3,11 @@ import asyncio
import json
from unittest.mock import patch, MagicMock, mock_open
from jose import jwt
import pytest
from homeassistant.components import cloud
from homeassistant.util.dt import utcnow
from tests.common import mock_coro
@ -72,7 +74,6 @@ def test_initialize_loads_info(mock_os, hass):
"""Test initialize will load info from config file."""
mock_os.path.isfile.return_value = True
mopen = mock_open(read_data=json.dumps({
'email': 'test-email',
'id_token': 'test-id-token',
'access_token': 'test-access-token',
'refresh_token': 'test-refresh-token',
@ -85,7 +86,6 @@ def test_initialize_loads_info(mock_os, hass):
with patch('homeassistant.components.cloud.open', mopen, create=True):
yield from cl.initialize()
assert cl.email == 'test-email'
assert cl.id_token == 'test-id-token'
assert cl.access_token == 'test-access-token'
assert cl.refresh_token == 'test-refresh-token'
@ -102,7 +102,6 @@ def test_logout_clears_info(mock_os, hass):
yield from cl.logout()
assert len(cl.iot.disconnect.mock_calls) == 1
assert cl.email is None
assert cl.id_token is None
assert cl.access_token is None
assert cl.refresh_token is None
@ -115,7 +114,6 @@ def test_write_user_info():
mopen = mock_open()
cl = cloud.Cloud(MagicMock(), cloud.MODE_DEV)
cl.email = 'test-email'
cl.id_token = 'test-id-token'
cl.access_token = 'test-access-token'
cl.refresh_token = 'test-refresh-token'
@ -129,7 +127,41 @@ def test_write_user_info():
data = json.loads(handle.write.mock_calls[0][1][0])
assert data == {
'access_token': 'test-access-token',
'email': 'test-email',
'id_token': 'test-id-token',
'refresh_token': 'test-refresh-token',
}
@asyncio.coroutine
def test_subscription_not_expired_without_sub_in_claim():
"""Test that we do not enforce subscriptions yet."""
cl = cloud.Cloud(None, cloud.MODE_DEV)
cl.id_token = jwt.encode({}, 'test')
assert not cl.subscription_expired
@asyncio.coroutine
def test_subscription_expired():
"""Test subscription being expired."""
cl = cloud.Cloud(None, cloud.MODE_DEV)
cl.id_token = jwt.encode({
'custom:sub-exp': '2017-11-13'
}, 'test')
with patch('homeassistant.util.dt.utcnow',
return_value=utcnow().replace(year=2018)):
assert cl.subscription_expired
@asyncio.coroutine
def test_subscription_not_expired():
"""Test subscription not being expired."""
cl = cloud.Cloud(None, cloud.MODE_DEV)
cl.id_token = jwt.encode({
'custom:sub-exp': '2017-11-13'
}, 'test')
with patch('homeassistant.util.dt.utcnow',
return_value=utcnow().replace(year=2017, month=11, day=9)):
assert not cl.subscription_expired

View file

@ -30,11 +30,16 @@ def mock_handle_message():
yield mock
@pytest.fixture
def mock_cloud():
"""Mock cloud class."""
return MagicMock(subscription_expired=False)
@asyncio.coroutine
def test_cloud_calling_handler(mock_client, mock_handle_message):
def test_cloud_calling_handler(mock_client, mock_handle_message, mock_cloud):
"""Test we call handle message with correct info."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.text,
json=MagicMock(return_value={
@ -53,8 +58,8 @@ def test_cloud_calling_handler(mock_client, mock_handle_message):
p_hass, p_cloud, handler_name, payload = \
mock_handle_message.mock_calls[0][1]
assert p_hass is cloud.hass
assert p_cloud is cloud
assert p_hass is mock_cloud.hass
assert p_cloud is mock_cloud
assert handler_name == 'test-handler'
assert payload == 'test-payload'
@ -67,10 +72,9 @@ def test_cloud_calling_handler(mock_client, mock_handle_message):
@asyncio.coroutine
def test_connection_msg_for_unknown_handler(mock_client):
def test_connection_msg_for_unknown_handler(mock_client, mock_cloud):
"""Test a msg for an unknown handler."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.text,
json=MagicMock(return_value={
@ -92,10 +96,10 @@ def test_connection_msg_for_unknown_handler(mock_client):
@asyncio.coroutine
def test_connection_msg_for_handler_raising(mock_client, mock_handle_message):
def test_connection_msg_for_handler_raising(mock_client, mock_handle_message,
mock_cloud):
"""Test we sent error when handler raises exception."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.text,
json=MagicMock(return_value={
@ -136,37 +140,34 @@ def test_handler_forwarding():
@asyncio.coroutine
def test_handling_core_messages(hass):
def test_handling_core_messages(hass, mock_cloud):
"""Test handling core messages."""
cloud = MagicMock()
cloud.logout.return_value = mock_coro()
yield from iot.async_handle_cloud(hass, cloud, {
mock_cloud.logout.return_value = mock_coro()
yield from iot.async_handle_cloud(hass, mock_cloud, {
'action': 'logout',
'reason': 'Logged in at two places.'
})
assert len(cloud.logout.mock_calls) == 1
assert len(mock_cloud.logout.mock_calls) == 1
@asyncio.coroutine
def test_cloud_getting_disconnected_by_server(mock_client, caplog):
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
"""Test server disconnecting instance."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.CLOSING,
))
yield from conn.connect()
assert 'Connection closed: Closed by server' in caplog.text
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
assert 'Connection closed: Connection cancelled.' in caplog.text
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
@asyncio.coroutine
def test_cloud_receiving_bytes(mock_client, caplog):
def test_cloud_receiving_bytes(mock_client, caplog, mock_cloud):
"""Test server disconnecting instance."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.BINARY,
))
@ -174,14 +175,13 @@ def test_cloud_receiving_bytes(mock_client, caplog):
yield from conn.connect()
assert 'Connection closed: Received non-Text message' in caplog.text
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
@asyncio.coroutine
def test_cloud_sending_invalid_json(mock_client, caplog):
def test_cloud_sending_invalid_json(mock_client, caplog, mock_cloud):
"""Test cloud sending invalid JSON."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.return_value = mock_coro(MagicMock(
type=WSMsgType.TEXT,
json=MagicMock(side_effect=ValueError)
@ -190,27 +190,25 @@ def test_cloud_sending_invalid_json(mock_client, caplog):
yield from conn.connect()
assert 'Connection closed: Received invalid JSON.' in caplog.text
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
@asyncio.coroutine
def test_cloud_check_token_raising(mock_client, caplog):
def test_cloud_check_token_raising(mock_client, caplog, mock_cloud):
"""Test cloud sending invalid JSON."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.side_effect = auth_api.CloudError
yield from conn.connect()
assert 'Unable to connect: Unable to refresh token.' in caplog.text
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
@asyncio.coroutine
def test_cloud_connect_invalid_auth(mock_client, caplog):
def test_cloud_connect_invalid_auth(mock_client, caplog, mock_cloud):
"""Test invalid auth detected by server."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.side_effect = \
client_exceptions.WSServerHandshakeError(None, None, code=401)
@ -220,10 +218,9 @@ def test_cloud_connect_invalid_auth(mock_client, caplog):
@asyncio.coroutine
def test_cloud_unable_to_connect(mock_client, caplog):
def test_cloud_unable_to_connect(mock_client, caplog, mock_cloud):
"""Test unable to connect error."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.side_effect = client_exceptions.ClientError(None, None)
yield from conn.connect()
@ -232,12 +229,28 @@ def test_cloud_unable_to_connect(mock_client, caplog):
@asyncio.coroutine
def test_cloud_random_exception(mock_client, caplog):
def test_cloud_random_exception(mock_client, caplog, mock_cloud):
"""Test random exception."""
cloud = MagicMock()
conn = iot.CloudIoT(cloud)
conn = iot.CloudIoT(mock_cloud)
mock_client.receive.side_effect = Exception
yield from conn.connect()
assert 'Unexpected error' in caplog.text
@asyncio.coroutine
def test_refresh_token_before_expiration_fails(hass, mock_cloud):
"""Test that we don't connect if token is expired."""
mock_cloud.subscription_expired = True
mock_cloud.hass = hass
conn = iot.CloudIoT(mock_cloud)
with patch('homeassistant.components.cloud.auth_api.check_token',
return_value=mock_coro()) as mock_check_token, \
patch.object(hass.components.persistent_notification,
'async_create') as mock_create:
yield from conn.connect()
assert len(mock_check_token.mock_calls) == 1
assert len(mock_create.mock_calls) == 1