parent
0cd3271dfa
commit
ea7ffff0ca
9 changed files with 219 additions and 91 deletions
|
@ -1,5 +1,6 @@
|
||||||
"""Component to integrate the Home Assistant cloud."""
|
"""Component to integrate the Home Assistant cloud."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -8,6 +9,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
EVENT_HOMEASSISTANT_START, CONF_REGION, CONF_MODE)
|
EVENT_HOMEASSISTANT_START, CONF_REGION, CONF_MODE)
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import http_api, iot
|
from . import http_api, iot
|
||||||
from .const import CONFIG_DIR, DOMAIN, SERVERS
|
from .const import CONFIG_DIR, DOMAIN, SERVERS
|
||||||
|
@ -66,7 +68,6 @@ class Cloud:
|
||||||
"""Create an instance of Cloud."""
|
"""Create an instance of Cloud."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.email = None
|
|
||||||
self.id_token = None
|
self.id_token = None
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.refresh_token = None
|
self.refresh_token = None
|
||||||
|
@ -89,7 +90,29 @@ class Cloud:
|
||||||
@property
|
@property
|
||||||
def is_logged_in(self):
|
def is_logged_in(self):
|
||||||
"""Get if cloud is logged in."""
|
"""Get if cloud is logged in."""
|
||||||
return self.email is not None
|
return self.id_token is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subscription_expired(self):
|
||||||
|
"""Return a boolen if the subscription has expired."""
|
||||||
|
# For now, don't enforce subscriptions to exist
|
||||||
|
if 'custom:sub-exp' not in self.claims:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return dt_util.utcnow() > self.expiration_date
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expiration_date(self):
|
||||||
|
"""Return the subscription expiration as a UTC datetime object."""
|
||||||
|
return datetime.combine(
|
||||||
|
dt_util.parse_date(self.claims['custom:sub-exp']),
|
||||||
|
datetime.min.time()).replace(tzinfo=dt_util.UTC)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def claims(self):
|
||||||
|
"""Get the claims from the id token."""
|
||||||
|
from jose import jwt
|
||||||
|
return jwt.get_unverified_claims(self.id_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_info_path(self):
|
def user_info_path(self):
|
||||||
|
@ -110,18 +133,20 @@ class Cloud:
|
||||||
if os.path.isfile(user_info):
|
if os.path.isfile(user_info):
|
||||||
with open(user_info, 'rt') as file:
|
with open(user_info, 'rt') as file:
|
||||||
info = json.loads(file.read())
|
info = json.loads(file.read())
|
||||||
self.email = info['email']
|
|
||||||
self.id_token = info['id_token']
|
self.id_token = info['id_token']
|
||||||
self.access_token = info['access_token']
|
self.access_token = info['access_token']
|
||||||
self.refresh_token = info['refresh_token']
|
self.refresh_token = info['refresh_token']
|
||||||
|
|
||||||
yield from self.hass.async_add_job(load_config)
|
yield from self.hass.async_add_job(load_config)
|
||||||
|
|
||||||
if self.email is not None:
|
if self.id_token is not None:
|
||||||
yield from self.iot.connect()
|
yield from self.iot.connect()
|
||||||
|
|
||||||
def path(self, *parts):
|
def path(self, *parts):
|
||||||
"""Get config path inside cloud dir."""
|
"""Get config path inside cloud dir.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
return self.hass.config.path(CONFIG_DIR, *parts)
|
return self.hass.config.path(CONFIG_DIR, *parts)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -129,7 +154,6 @@ class Cloud:
|
||||||
"""Close connection and remove all credentials."""
|
"""Close connection and remove all credentials."""
|
||||||
yield from self.iot.disconnect()
|
yield from self.iot.disconnect()
|
||||||
|
|
||||||
self.email = None
|
|
||||||
self.id_token = None
|
self.id_token = None
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.refresh_token = None
|
self.refresh_token = None
|
||||||
|
@ -141,7 +165,6 @@ class Cloud:
|
||||||
"""Write user info to a file."""
|
"""Write user info to a file."""
|
||||||
with open(self.user_info_path, 'wt') as file:
|
with open(self.user_info_path, 'wt') as file:
|
||||||
file.write(json.dumps({
|
file.write(json.dumps({
|
||||||
'email': self.email,
|
|
||||||
'id_token': self.id_token,
|
'id_token': self.id_token,
|
||||||
'access_token': self.access_token,
|
'access_token': self.access_token,
|
||||||
'refresh_token': self.refresh_token,
|
'refresh_token': self.refresh_token,
|
||||||
|
|
|
@ -113,7 +113,6 @@ def login(cloud, email, password):
|
||||||
cloud.id_token = cognito.id_token
|
cloud.id_token = cognito.id_token
|
||||||
cloud.access_token = cognito.access_token
|
cloud.access_token = cognito.access_token
|
||||||
cloud.refresh_token = cognito.refresh_token
|
cloud.refresh_token = cognito.refresh_token
|
||||||
cloud.email = email
|
|
||||||
cloud.write_user_info()
|
cloud.write_user_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,3 +12,8 @@ SERVERS = {
|
||||||
# 'relayer': ''
|
# 'relayer': ''
|
||||||
# }
|
# }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MESSAGE_EXPIRATION = """
|
||||||
|
It looks like your Home Assistant Cloud subscription has expired. Please check
|
||||||
|
your [account page](/config/cloud/account) to continue using the service.
|
||||||
|
"""
|
||||||
|
|
|
@ -79,8 +79,10 @@ class CloudLoginView(HomeAssistantView):
|
||||||
with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
|
with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
|
||||||
yield from hass.async_add_job(auth_api.login, cloud, data['email'],
|
yield from hass.async_add_job(auth_api.login, cloud, data['email'],
|
||||||
data['password'])
|
data['password'])
|
||||||
hass.async_add_job(cloud.iot.connect)
|
|
||||||
|
|
||||||
|
hass.async_add_job(cloud.iot.connect)
|
||||||
|
# Allow cloud to start connecting.
|
||||||
|
yield from asyncio.sleep(0, loop=hass.loop)
|
||||||
return self.json(_account_data(cloud))
|
return self.json(_account_data(cloud))
|
||||||
|
|
||||||
|
|
||||||
|
@ -222,6 +224,10 @@ class CloudConfirmForgotPasswordView(HomeAssistantView):
|
||||||
|
|
||||||
def _account_data(cloud):
|
def _account_data(cloud):
|
||||||
"""Generate the auth data JSON response."""
|
"""Generate the auth data JSON response."""
|
||||||
|
claims = cloud.claims
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'email': cloud.email
|
'email': claims['email'],
|
||||||
|
'sub_exp': claims.get('custom:sub-exp'),
|
||||||
|
'cloud': cloud.iot.state,
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,11 +9,16 @@ from homeassistant.components.alexa import smart_home
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from . import auth_api
|
from . import auth_api
|
||||||
|
from .const import MESSAGE_EXPIRATION
|
||||||
|
|
||||||
|
|
||||||
HANDLERS = Registry()
|
HANDLERS = Registry()
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STATE_CONNECTING = 'connecting'
|
||||||
|
STATE_CONNECTED = 'connected'
|
||||||
|
STATE_DISCONNECTED = 'disconnected'
|
||||||
|
|
||||||
|
|
||||||
class UnknownHandler(Exception):
|
class UnknownHandler(Exception):
|
||||||
"""Exception raised when trying to handle unknown handler."""
|
"""Exception raised when trying to handle unknown handler."""
|
||||||
|
@ -25,27 +30,41 @@ class CloudIoT:
|
||||||
def __init__(self, cloud):
|
def __init__(self, cloud):
|
||||||
"""Initialize the CloudIoT class."""
|
"""Initialize the CloudIoT class."""
|
||||||
self.cloud = cloud
|
self.cloud = cloud
|
||||||
|
# The WebSocket client
|
||||||
self.client = None
|
self.client = None
|
||||||
|
# Scheduled sleep task till next connection retry
|
||||||
|
self.retry_task = None
|
||||||
|
# Boolean to indicate if we wanted the connection to close
|
||||||
self.close_requested = False
|
self.close_requested = False
|
||||||
|
# The current number of attempts to connect, impacts wait time
|
||||||
self.tries = 0
|
self.tries = 0
|
||||||
|
# Current state of the connection
|
||||||
@property
|
self.state = STATE_DISCONNECTED
|
||||||
def is_connected(self):
|
|
||||||
"""Return if connected to the cloud."""
|
|
||||||
return self.client is not None
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""Connect to the IoT broker."""
|
"""Connect to the IoT broker."""
|
||||||
if self.client is not None:
|
|
||||||
raise RuntimeError('Cannot connect while already connected')
|
|
||||||
|
|
||||||
self.close_requested = False
|
|
||||||
|
|
||||||
hass = self.cloud.hass
|
hass = self.cloud.hass
|
||||||
remove_hass_stop_listener = None
|
if self.cloud.subscription_expired:
|
||||||
|
# Try refreshing the token to see if it is still expired.
|
||||||
|
yield from hass.async_add_job(auth_api.check_token, self.cloud)
|
||||||
|
|
||||||
|
if self.cloud.subscription_expired:
|
||||||
|
hass.components.persistent_notification.async_create(
|
||||||
|
MESSAGE_EXPIRATION, 'Subscription expired',
|
||||||
|
'cloud_subscription_expired')
|
||||||
|
self.state = STATE_DISCONNECTED
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.state == STATE_CONNECTED:
|
||||||
|
raise RuntimeError('Already connected')
|
||||||
|
|
||||||
|
self.state = STATE_CONNECTING
|
||||||
|
self.close_requested = False
|
||||||
|
remove_hass_stop_listener = None
|
||||||
session = async_get_clientsession(self.cloud.hass)
|
session = async_get_clientsession(self.cloud.hass)
|
||||||
|
client = None
|
||||||
|
disconnect_warn = None
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _handle_hass_stop(event):
|
def _handle_hass_stop(event):
|
||||||
|
@ -54,8 +73,6 @@ class CloudIoT:
|
||||||
remove_hass_stop_listener = None
|
remove_hass_stop_listener = None
|
||||||
yield from self.disconnect()
|
yield from self.disconnect()
|
||||||
|
|
||||||
client = None
|
|
||||||
disconnect_warn = None
|
|
||||||
try:
|
try:
|
||||||
yield from hass.async_add_job(auth_api.check_token, self.cloud)
|
yield from hass.async_add_job(auth_api.check_token, self.cloud)
|
||||||
|
|
||||||
|
@ -70,13 +87,14 @@ class CloudIoT:
|
||||||
EVENT_HOMEASSISTANT_STOP, _handle_hass_stop)
|
EVENT_HOMEASSISTANT_STOP, _handle_hass_stop)
|
||||||
|
|
||||||
_LOGGER.info('Connected')
|
_LOGGER.info('Connected')
|
||||||
|
self.state = STATE_CONNECTED
|
||||||
|
|
||||||
while not client.closed:
|
while not client.closed:
|
||||||
msg = yield from client.receive()
|
msg = yield from client.receive()
|
||||||
|
|
||||||
if msg.type in (WSMsgType.ERROR, WSMsgType.CLOSED,
|
if msg.type in (WSMsgType.ERROR, WSMsgType.CLOSED,
|
||||||
WSMsgType.CLOSING):
|
WSMsgType.CLOSING):
|
||||||
disconnect_warn = 'Closed by server'
|
disconnect_warn = 'Connection cancelled.'
|
||||||
break
|
break
|
||||||
|
|
||||||
elif msg.type != WSMsgType.TEXT:
|
elif msg.type != WSMsgType.TEXT:
|
||||||
|
@ -144,20 +162,33 @@ class CloudIoT:
|
||||||
self.client = None
|
self.client = None
|
||||||
yield from client.close()
|
yield from client.close()
|
||||||
|
|
||||||
if not self.close_requested:
|
if self.close_requested:
|
||||||
|
self.state = STATE_DISCONNECTED
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.state = STATE_CONNECTING
|
||||||
self.tries += 1
|
self.tries += 1
|
||||||
|
|
||||||
|
try:
|
||||||
# Sleep 0, 5, 10, 15 … up to 30 seconds between retries
|
# Sleep 0, 5, 10, 15 … up to 30 seconds between retries
|
||||||
yield from asyncio.sleep(
|
self.retry_task = hass.async_add_job(asyncio.sleep(
|
||||||
min(30, (self.tries - 1) * 5), loop=hass.loop)
|
min(30, (self.tries - 1) * 5), loop=hass.loop))
|
||||||
|
yield from self.retry_task
|
||||||
|
self.retry_task = None
|
||||||
hass.async_add_job(self.connect())
|
hass.async_add_job(self.connect())
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Happens if disconnect called
|
||||||
|
pass
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnect the client."""
|
"""Disconnect the client."""
|
||||||
self.close_requested = True
|
self.close_requested = True
|
||||||
|
|
||||||
|
if self.client is not None:
|
||||||
yield from self.client.close()
|
yield from self.client.close()
|
||||||
|
elif self.retry_task is not None:
|
||||||
|
self.retry_task.cancel()
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
|
|
@ -69,7 +69,6 @@ def test_login(mock_cognito):
|
||||||
auth_api.login(cloud, 'user', 'pass')
|
auth_api.login(cloud, 'user', 'pass')
|
||||||
|
|
||||||
assert len(mock_cognito.authenticate.mock_calls) == 1
|
assert len(mock_cognito.authenticate.mock_calls) == 1
|
||||||
assert cloud.email == 'user'
|
|
||||||
assert cloud.id_token == 'test_id_token'
|
assert cloud.id_token == 'test_id_token'
|
||||||
assert cloud.access_token == 'test_access_token'
|
assert cloud.access_token == 'test_access_token'
|
||||||
assert cloud.refresh_token == 'test_refresh_token'
|
assert cloud.refresh_token == 'test_refresh_token'
|
||||||
|
|
|
@ -3,9 +3,10 @@ import asyncio
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
from homeassistant.bootstrap import async_setup_component
|
from homeassistant.bootstrap import async_setup_component
|
||||||
from homeassistant.components.cloud import DOMAIN, auth_api
|
from homeassistant.components.cloud import DOMAIN, auth_api, iot
|
||||||
|
|
||||||
from tests.common import mock_coro
|
from tests.common import mock_coro
|
||||||
|
|
||||||
|
@ -23,7 +24,8 @@ def cloud_client(hass, test_client):
|
||||||
'relayer': 'relayer',
|
'relayer': 'relayer',
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
return hass.loop.run_until_complete(test_client(hass.http.app))
|
with patch('homeassistant.components.cloud.Cloud.write_user_info'):
|
||||||
|
yield hass.loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -43,21 +45,35 @@ def test_account_view_no_account(cloud_client):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_account_view(hass, cloud_client):
|
def test_account_view(hass, cloud_client):
|
||||||
"""Test fetching account if no account available."""
|
"""Test fetching account if no account available."""
|
||||||
hass.data[DOMAIN].email = 'hello@home-assistant.io'
|
hass.data[DOMAIN].id_token = jwt.encode({
|
||||||
|
'email': 'hello@home-assistant.io',
|
||||||
|
'custom:sub-exp': '2018-01-03'
|
||||||
|
}, 'test')
|
||||||
|
hass.data[DOMAIN].iot.state = iot.STATE_CONNECTED
|
||||||
req = yield from cloud_client.get('/api/cloud/account')
|
req = yield from cloud_client.get('/api/cloud/account')
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
result = yield from req.json()
|
result = yield from req.json()
|
||||||
assert result == {'email': 'hello@home-assistant.io'}
|
assert result == {
|
||||||
|
'email': 'hello@home-assistant.io',
|
||||||
|
'sub_exp': '2018-01-03',
|
||||||
|
'cloud': iot.STATE_CONNECTED,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_login_view(hass, cloud_client):
|
def test_login_view(hass, cloud_client, mock_cognito):
|
||||||
"""Test logging in."""
|
"""Test logging in."""
|
||||||
hass.data[DOMAIN].email = 'hello@home-assistant.io'
|
mock_cognito.id_token = jwt.encode({
|
||||||
|
'email': 'hello@home-assistant.io',
|
||||||
|
'custom:sub-exp': '2018-01-03'
|
||||||
|
}, 'test')
|
||||||
|
mock_cognito.access_token = 'access_token'
|
||||||
|
mock_cognito.refresh_token = 'refresh_token'
|
||||||
|
|
||||||
with patch('homeassistant.components.cloud.iot.CloudIoT.connect'), \
|
with patch('homeassistant.components.cloud.iot.CloudIoT.'
|
||||||
patch('homeassistant.components.cloud.'
|
'connect') as mock_connect, \
|
||||||
'auth_api.login') as mock_login:
|
patch('homeassistant.components.cloud.auth_api._authenticate',
|
||||||
|
return_value=mock_cognito) as mock_auth:
|
||||||
req = yield from cloud_client.post('/api/cloud/login', json={
|
req = yield from cloud_client.post('/api/cloud/login', json={
|
||||||
'email': 'my_username',
|
'email': 'my_username',
|
||||||
'password': 'my_password'
|
'password': 'my_password'
|
||||||
|
@ -65,9 +81,13 @@ def test_login_view(hass, cloud_client):
|
||||||
|
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
result = yield from req.json()
|
result = yield from req.json()
|
||||||
assert result == {'email': 'hello@home-assistant.io'}
|
assert result['email'] == 'hello@home-assistant.io'
|
||||||
assert len(mock_login.mock_calls) == 1
|
assert result['sub_exp'] == '2018-01-03'
|
||||||
cloud, result_user, result_pass = mock_login.mock_calls[0][1]
|
|
||||||
|
assert len(mock_connect.mock_calls) == 1
|
||||||
|
|
||||||
|
assert len(mock_auth.mock_calls) == 1
|
||||||
|
cloud, result_user, result_pass = mock_auth.mock_calls[0][1]
|
||||||
assert result_user == 'my_username'
|
assert result_user == 'my_username'
|
||||||
assert result_pass == 'my_password'
|
assert result_pass == 'my_password'
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,11 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
from unittest.mock import patch, MagicMock, mock_open
|
from unittest.mock import patch, MagicMock, mock_open
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import cloud
|
from homeassistant.components import cloud
|
||||||
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from tests.common import mock_coro
|
from tests.common import mock_coro
|
||||||
|
|
||||||
|
@ -72,7 +74,6 @@ def test_initialize_loads_info(mock_os, hass):
|
||||||
"""Test initialize will load info from config file."""
|
"""Test initialize will load info from config file."""
|
||||||
mock_os.path.isfile.return_value = True
|
mock_os.path.isfile.return_value = True
|
||||||
mopen = mock_open(read_data=json.dumps({
|
mopen = mock_open(read_data=json.dumps({
|
||||||
'email': 'test-email',
|
|
||||||
'id_token': 'test-id-token',
|
'id_token': 'test-id-token',
|
||||||
'access_token': 'test-access-token',
|
'access_token': 'test-access-token',
|
||||||
'refresh_token': 'test-refresh-token',
|
'refresh_token': 'test-refresh-token',
|
||||||
|
@ -85,7 +86,6 @@ def test_initialize_loads_info(mock_os, hass):
|
||||||
with patch('homeassistant.components.cloud.open', mopen, create=True):
|
with patch('homeassistant.components.cloud.open', mopen, create=True):
|
||||||
yield from cl.initialize()
|
yield from cl.initialize()
|
||||||
|
|
||||||
assert cl.email == 'test-email'
|
|
||||||
assert cl.id_token == 'test-id-token'
|
assert cl.id_token == 'test-id-token'
|
||||||
assert cl.access_token == 'test-access-token'
|
assert cl.access_token == 'test-access-token'
|
||||||
assert cl.refresh_token == 'test-refresh-token'
|
assert cl.refresh_token == 'test-refresh-token'
|
||||||
|
@ -102,7 +102,6 @@ def test_logout_clears_info(mock_os, hass):
|
||||||
yield from cl.logout()
|
yield from cl.logout()
|
||||||
|
|
||||||
assert len(cl.iot.disconnect.mock_calls) == 1
|
assert len(cl.iot.disconnect.mock_calls) == 1
|
||||||
assert cl.email is None
|
|
||||||
assert cl.id_token is None
|
assert cl.id_token is None
|
||||||
assert cl.access_token is None
|
assert cl.access_token is None
|
||||||
assert cl.refresh_token is None
|
assert cl.refresh_token is None
|
||||||
|
@ -115,7 +114,6 @@ def test_write_user_info():
|
||||||
mopen = mock_open()
|
mopen = mock_open()
|
||||||
|
|
||||||
cl = cloud.Cloud(MagicMock(), cloud.MODE_DEV)
|
cl = cloud.Cloud(MagicMock(), cloud.MODE_DEV)
|
||||||
cl.email = 'test-email'
|
|
||||||
cl.id_token = 'test-id-token'
|
cl.id_token = 'test-id-token'
|
||||||
cl.access_token = 'test-access-token'
|
cl.access_token = 'test-access-token'
|
||||||
cl.refresh_token = 'test-refresh-token'
|
cl.refresh_token = 'test-refresh-token'
|
||||||
|
@ -129,7 +127,41 @@ def test_write_user_info():
|
||||||
data = json.loads(handle.write.mock_calls[0][1][0])
|
data = json.loads(handle.write.mock_calls[0][1][0])
|
||||||
assert data == {
|
assert data == {
|
||||||
'access_token': 'test-access-token',
|
'access_token': 'test-access-token',
|
||||||
'email': 'test-email',
|
|
||||||
'id_token': 'test-id-token',
|
'id_token': 'test-id-token',
|
||||||
'refresh_token': 'test-refresh-token',
|
'refresh_token': 'test-refresh-token',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_subscription_not_expired_without_sub_in_claim():
|
||||||
|
"""Test that we do not enforce subscriptions yet."""
|
||||||
|
cl = cloud.Cloud(None, cloud.MODE_DEV)
|
||||||
|
cl.id_token = jwt.encode({}, 'test')
|
||||||
|
|
||||||
|
assert not cl.subscription_expired
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_subscription_expired():
|
||||||
|
"""Test subscription being expired."""
|
||||||
|
cl = cloud.Cloud(None, cloud.MODE_DEV)
|
||||||
|
cl.id_token = jwt.encode({
|
||||||
|
'custom:sub-exp': '2017-11-13'
|
||||||
|
}, 'test')
|
||||||
|
|
||||||
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
|
return_value=utcnow().replace(year=2018)):
|
||||||
|
assert cl.subscription_expired
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_subscription_not_expired():
|
||||||
|
"""Test subscription not being expired."""
|
||||||
|
cl = cloud.Cloud(None, cloud.MODE_DEV)
|
||||||
|
cl.id_token = jwt.encode({
|
||||||
|
'custom:sub-exp': '2017-11-13'
|
||||||
|
}, 'test')
|
||||||
|
|
||||||
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
|
return_value=utcnow().replace(year=2017, month=11, day=9)):
|
||||||
|
assert not cl.subscription_expired
|
||||||
|
|
|
@ -30,11 +30,16 @@ def mock_handle_message():
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_cloud():
|
||||||
|
"""Mock cloud class."""
|
||||||
|
return MagicMock(subscription_expired=False)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_calling_handler(mock_client, mock_handle_message):
|
def test_cloud_calling_handler(mock_client, mock_handle_message, mock_cloud):
|
||||||
"""Test we call handle message with correct info."""
|
"""Test we call handle message with correct info."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.text,
|
type=WSMsgType.text,
|
||||||
json=MagicMock(return_value={
|
json=MagicMock(return_value={
|
||||||
|
@ -53,8 +58,8 @@ def test_cloud_calling_handler(mock_client, mock_handle_message):
|
||||||
p_hass, p_cloud, handler_name, payload = \
|
p_hass, p_cloud, handler_name, payload = \
|
||||||
mock_handle_message.mock_calls[0][1]
|
mock_handle_message.mock_calls[0][1]
|
||||||
|
|
||||||
assert p_hass is cloud.hass
|
assert p_hass is mock_cloud.hass
|
||||||
assert p_cloud is cloud
|
assert p_cloud is mock_cloud
|
||||||
assert handler_name == 'test-handler'
|
assert handler_name == 'test-handler'
|
||||||
assert payload == 'test-payload'
|
assert payload == 'test-payload'
|
||||||
|
|
||||||
|
@ -67,10 +72,9 @@ def test_cloud_calling_handler(mock_client, mock_handle_message):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_connection_msg_for_unknown_handler(mock_client):
|
def test_connection_msg_for_unknown_handler(mock_client, mock_cloud):
|
||||||
"""Test a msg for an unknown handler."""
|
"""Test a msg for an unknown handler."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.text,
|
type=WSMsgType.text,
|
||||||
json=MagicMock(return_value={
|
json=MagicMock(return_value={
|
||||||
|
@ -92,10 +96,10 @@ def test_connection_msg_for_unknown_handler(mock_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_connection_msg_for_handler_raising(mock_client, mock_handle_message):
|
def test_connection_msg_for_handler_raising(mock_client, mock_handle_message,
|
||||||
|
mock_cloud):
|
||||||
"""Test we sent error when handler raises exception."""
|
"""Test we sent error when handler raises exception."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.text,
|
type=WSMsgType.text,
|
||||||
json=MagicMock(return_value={
|
json=MagicMock(return_value={
|
||||||
|
@ -136,37 +140,34 @@ def test_handler_forwarding():
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_handling_core_messages(hass):
|
def test_handling_core_messages(hass, mock_cloud):
|
||||||
"""Test handling core messages."""
|
"""Test handling core messages."""
|
||||||
cloud = MagicMock()
|
mock_cloud.logout.return_value = mock_coro()
|
||||||
cloud.logout.return_value = mock_coro()
|
yield from iot.async_handle_cloud(hass, mock_cloud, {
|
||||||
yield from iot.async_handle_cloud(hass, cloud, {
|
|
||||||
'action': 'logout',
|
'action': 'logout',
|
||||||
'reason': 'Logged in at two places.'
|
'reason': 'Logged in at two places.'
|
||||||
})
|
})
|
||||||
assert len(cloud.logout.mock_calls) == 1
|
assert len(mock_cloud.logout.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_getting_disconnected_by_server(mock_client, caplog):
|
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
|
||||||
"""Test server disconnecting instance."""
|
"""Test server disconnecting instance."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.CLOSING,
|
type=WSMsgType.CLOSING,
|
||||||
))
|
))
|
||||||
|
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
|
||||||
assert 'Connection closed: Closed by server' in caplog.text
|
assert 'Connection closed: Connection cancelled.' in caplog.text
|
||||||
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
|
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_receiving_bytes(mock_client, caplog):
|
def test_cloud_receiving_bytes(mock_client, caplog, mock_cloud):
|
||||||
"""Test server disconnecting instance."""
|
"""Test server disconnecting instance."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.BINARY,
|
type=WSMsgType.BINARY,
|
||||||
))
|
))
|
||||||
|
@ -174,14 +175,13 @@ def test_cloud_receiving_bytes(mock_client, caplog):
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
|
||||||
assert 'Connection closed: Received non-Text message' in caplog.text
|
assert 'Connection closed: Received non-Text message' in caplog.text
|
||||||
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
|
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_sending_invalid_json(mock_client, caplog):
|
def test_cloud_sending_invalid_json(mock_client, caplog, mock_cloud):
|
||||||
"""Test cloud sending invalid JSON."""
|
"""Test cloud sending invalid JSON."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.return_value = mock_coro(MagicMock(
|
mock_client.receive.return_value = mock_coro(MagicMock(
|
||||||
type=WSMsgType.TEXT,
|
type=WSMsgType.TEXT,
|
||||||
json=MagicMock(side_effect=ValueError)
|
json=MagicMock(side_effect=ValueError)
|
||||||
|
@ -190,27 +190,25 @@ def test_cloud_sending_invalid_json(mock_client, caplog):
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
|
||||||
assert 'Connection closed: Received invalid JSON.' in caplog.text
|
assert 'Connection closed: Received invalid JSON.' in caplog.text
|
||||||
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
|
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_check_token_raising(mock_client, caplog):
|
def test_cloud_check_token_raising(mock_client, caplog, mock_cloud):
|
||||||
"""Test cloud sending invalid JSON."""
|
"""Test cloud sending invalid JSON."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.side_effect = auth_api.CloudError
|
mock_client.receive.side_effect = auth_api.CloudError
|
||||||
|
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
|
||||||
assert 'Unable to connect: Unable to refresh token.' in caplog.text
|
assert 'Unable to connect: Unable to refresh token.' in caplog.text
|
||||||
assert 'connect' in str(cloud.hass.async_add_job.mock_calls[-1][1][0])
|
assert 'connect' in str(mock_cloud.hass.async_add_job.mock_calls[-1][1][0])
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_connect_invalid_auth(mock_client, caplog):
|
def test_cloud_connect_invalid_auth(mock_client, caplog, mock_cloud):
|
||||||
"""Test invalid auth detected by server."""
|
"""Test invalid auth detected by server."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.side_effect = \
|
mock_client.receive.side_effect = \
|
||||||
client_exceptions.WSServerHandshakeError(None, None, code=401)
|
client_exceptions.WSServerHandshakeError(None, None, code=401)
|
||||||
|
|
||||||
|
@ -220,10 +218,9 @@ def test_cloud_connect_invalid_auth(mock_client, caplog):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_unable_to_connect(mock_client, caplog):
|
def test_cloud_unable_to_connect(mock_client, caplog, mock_cloud):
|
||||||
"""Test unable to connect error."""
|
"""Test unable to connect error."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.side_effect = client_exceptions.ClientError(None, None)
|
mock_client.receive.side_effect = client_exceptions.ClientError(None, None)
|
||||||
|
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
@ -232,12 +229,28 @@ def test_cloud_unable_to_connect(mock_client, caplog):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_cloud_random_exception(mock_client, caplog):
|
def test_cloud_random_exception(mock_client, caplog, mock_cloud):
|
||||||
"""Test random exception."""
|
"""Test random exception."""
|
||||||
cloud = MagicMock()
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
conn = iot.CloudIoT(cloud)
|
|
||||||
mock_client.receive.side_effect = Exception
|
mock_client.receive.side_effect = Exception
|
||||||
|
|
||||||
yield from conn.connect()
|
yield from conn.connect()
|
||||||
|
|
||||||
assert 'Unexpected error' in caplog.text
|
assert 'Unexpected error' in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_refresh_token_before_expiration_fails(hass, mock_cloud):
|
||||||
|
"""Test that we don't connect if token is expired."""
|
||||||
|
mock_cloud.subscription_expired = True
|
||||||
|
mock_cloud.hass = hass
|
||||||
|
conn = iot.CloudIoT(mock_cloud)
|
||||||
|
|
||||||
|
with patch('homeassistant.components.cloud.auth_api.check_token',
|
||||||
|
return_value=mock_coro()) as mock_check_token, \
|
||||||
|
patch.object(hass.components.persistent_notification,
|
||||||
|
'async_create') as mock_create:
|
||||||
|
yield from conn.connect()
|
||||||
|
|
||||||
|
assert len(mock_check_token.mock_calls) == 1
|
||||||
|
assert len(mock_create.mock_calls) == 1
|
||||||
|
|
Loading…
Add table
Reference in a new issue