Update cloud auth (#9357)

* Update cloud logic

* Lint

* Update test requirements

* Address commments, fix tests

* Add credentials
This commit is contained in:
Paulus Schoutsen 2017-09-12 09:47:04 -07:00 committed by Pascal Vizeli
parent 90f9a6bc0a
commit c9fc3fae6e
12 changed files with 1047 additions and 822 deletions

View file

@ -1,297 +0,0 @@
"""Package to offer tools to communicate with the cloud."""
import asyncio
from datetime import timedelta
import json
import logging
import os
from urllib.parse import urljoin
import aiohttp
import async_timeout
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.dt import utcnow
from .const import AUTH_FILE, REQUEST_TIMEOUT, SERVERS
from .util import get_mode
_LOGGER = logging.getLogger(__name__)
URL_CREATE_TOKEN = 'o/token/'
URL_REVOKE_TOKEN = 'o/revoke_token/'
URL_ACCOUNT = 'account.json'
class CloudError(Exception):
"""Base class for cloud related errors."""
def __init__(self, reason=None, status=None):
"""Initialize a cloud error."""
super().__init__(reason)
self.status = status
class Unauthenticated(CloudError):
"""Raised when authentication failed."""
class UnknownError(CloudError):
"""Raised when an unknown error occurred."""
@asyncio.coroutine
def async_load_auth(hass):
"""Load authentication from disk and verify it."""
auth = yield from hass.async_add_job(_read_auth, hass)
if not auth:
return None
cloud = Cloud(hass, auth)
try:
with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
auth_check = yield from cloud.async_refresh_account_info()
if not auth_check:
_LOGGER.error('Unable to validate credentials.')
return None
return cloud
except asyncio.TimeoutError:
_LOGGER.error('Unable to reach server to validate credentials.')
return None
@asyncio.coroutine
def async_login(hass, username, password, scope=None):
"""Get a token using a username and password.
Returns a coroutine.
"""
data = {
'grant_type': 'password',
'username': username,
'password': password
}
if scope is not None:
data['scope'] = scope
auth = yield from _async_get_token(hass, data)
yield from hass.async_add_job(_write_auth, hass, auth)
return Cloud(hass, auth)
@asyncio.coroutine
def _async_get_token(hass, data):
"""Get a new token and return it as a dictionary.
Raises exceptions when errors occur:
- Unauthenticated
- UnknownError
"""
session = async_get_clientsession(hass)
auth = aiohttp.BasicAuth(*_client_credentials(hass))
try:
req = yield from session.post(
_url(hass, URL_CREATE_TOKEN),
data=data,
auth=auth
)
if req.status == 401:
_LOGGER.error('Cloud login failed: %d', req.status)
raise Unauthenticated(status=req.status)
elif req.status != 200:
_LOGGER.error('Cloud login failed: %d', req.status)
raise UnknownError(status=req.status)
response = yield from req.json()
response['expires_at'] = \
(utcnow() + timedelta(seconds=response['expires_in'])).isoformat()
return response
except aiohttp.ClientError:
raise UnknownError()
class Cloud:
"""Store Hass Cloud info."""
def __init__(self, hass, auth):
"""Initialize Hass cloud info object."""
self.hass = hass
self.auth = auth
self.account = None
@property
def access_token(self):
"""Return access token."""
return self.auth['access_token']
@property
def refresh_token(self):
"""Get refresh token."""
return self.auth['refresh_token']
@asyncio.coroutine
def async_refresh_account_info(self):
"""Refresh the account info."""
req = yield from self.async_request('get', URL_ACCOUNT)
if req.status != 200:
return False
self.account = yield from req.json()
return True
@asyncio.coroutine
def async_refresh_access_token(self):
"""Get a token using a refresh token."""
try:
self.auth = yield from _async_get_token(self.hass, {
'grant_type': 'refresh_token',
'refresh_token': self.refresh_token,
})
yield from self.hass.async_add_job(
_write_auth, self.hass, self.auth)
return True
except CloudError:
return False
@asyncio.coroutine
def async_revoke_access_token(self):
"""Revoke active access token."""
session = async_get_clientsession(self.hass)
client_id, client_secret = _client_credentials(self.hass)
data = {
'token': self.access_token,
'client_id': client_id,
'client_secret': client_secret
}
try:
req = yield from session.post(
_url(self.hass, URL_REVOKE_TOKEN),
data=data,
)
if req.status != 200:
_LOGGER.error('Cloud logout failed: %d', req.status)
raise UnknownError(status=req.status)
self.auth = None
yield from self.hass.async_add_job(
_write_auth, self.hass, None)
except aiohttp.ClientError:
raise UnknownError()
@asyncio.coroutine
def async_request(self, method, path, **kwargs):
"""Make a request to Home Assistant cloud.
Will refresh the token if necessary.
"""
session = async_get_clientsession(self.hass)
url = _url(self.hass, path)
if 'headers' not in kwargs:
kwargs['headers'] = {}
kwargs['headers']['authorization'] = \
'Bearer {}'.format(self.access_token)
request = yield from session.request(method, url, **kwargs)
if request.status != 403:
return request
# Maybe token expired. Try refreshing it.
reauth = yield from self.async_refresh_access_token()
if not reauth:
return request
# Release old connection back to the pool.
yield from request.release()
kwargs['headers']['authorization'] = \
'Bearer {}'.format(self.access_token)
# If we are not already fetching the account info,
# refresh the account info.
if path != URL_ACCOUNT:
yield from self.async_refresh_account_info()
request = yield from session.request(method, url, **kwargs)
return request
def _read_auth(hass):
"""Read auth file."""
path = hass.config.path(AUTH_FILE)
if not os.path.isfile(path):
return None
with open(path) as file:
return json.load(file).get(get_mode(hass))
def _write_auth(hass, data):
"""Write auth info for specified mode.
Pass in None for data to remove authentication for that mode.
"""
path = hass.config.path(AUTH_FILE)
mode = get_mode(hass)
if os.path.isfile(path):
with open(path) as file:
content = json.load(file)
else:
content = {}
if data is None:
content.pop(mode, None)
else:
content[mode] = data
with open(path, 'wt') as file:
file.write(json.dumps(content, indent=4, sort_keys=True))
def _client_credentials(hass):
"""Get the client credentials.
Async friendly.
"""
mode = get_mode(hass)
if mode not in SERVERS:
raise ValueError('Mode {} is not supported.'.format(mode))
return SERVERS[mode]['client_id'], SERVERS[mode]['client_secret']
def _url(hass, path):
"""Generate a url for the cloud.
Async friendly.
"""
mode = get_mode(hass)
if mode not in SERVERS:
raise ValueError('Mode {} is not supported.'.format(mode))
return urljoin(SERVERS[mode]['host'], path)