hass-core/homeassistant/components/cloud/cloud_api.py
Paulus Schoutsen 0b58d5405e Add cloud auth support ()
* Add initial cloud auth

* Move hass.data to a dict

* Move mode into helper

* Fix bugs after refactor

* Add tests

* Clean up scripts file after test config

* Lint

* Update __init__.py
2017-08-29 13:40:08 -07:00

297 lines
7.5 KiB
Python

"""Package to offer tools to communicate with the cloud."""
import asyncio
from datetime import timedelta
import json
import logging
import os
import tempfile
from urllib.parse import urljoin

import aiohttp
import async_timeout

from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.dt import utcnow

from .const import AUTH_FILE, REQUEST_TIMEOUT, SERVERS
from .util import get_mode
_LOGGER = logging.getLogger(__name__)

# Endpoint paths on the cloud server; ``_url`` joins them onto the host
# configured for the active mode.
URL_CREATE_TOKEN = 'o/token/'
URL_REVOKE_TOKEN = 'o/revoke_token/'
URL_ACCOUNT = 'account.json'
class CloudError(Exception):
    """Root of the exceptions raised while talking to the cloud.

    ``reason`` is handed to ``Exception`` as its single argument and
    ``status`` keeps the HTTP status code of the failed response, if any.
    """

    def __init__(self, reason=None, status=None):
        """Create the error from an optional reason and HTTP status."""
        self.status = status
        super().__init__(reason)
class Unauthenticated(CloudError):
    """Raised when the cloud refuses the supplied credentials."""
class UnknownError(CloudError):
    """Raised when a cloud request fails for any other reason."""
@asyncio.coroutine
def async_load_auth(hass):
    """Load authentication from disk and verify it.

    Returns a ``Cloud`` whose account info has just been fetched, or
    ``None`` when no credentials are stored for the current mode, the
    server does not accept them, or the server cannot be reached within
    ``REQUEST_TIMEOUT``.
    """
    auth = yield from hass.async_add_job(_read_auth, hass)

    if not auth:
        return None

    cloud = Cloud(hass, auth)

    try:
        with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
            auth_check = yield from cloud.async_refresh_account_info()
    except (asyncio.TimeoutError, aiohttp.ClientError):
        # A refused connection or DNS failure means the server is just as
        # unreachable as a timeout does; report it the same way instead of
        # letting the aiohttp exception escape to the caller.
        _LOGGER.error('Unable to reach server to validate credentials.')
        return None

    if not auth_check:
        _LOGGER.error('Unable to validate credentials.')
        return None

    return cloud
@asyncio.coroutine
def async_login(hass, username, password, scope=None):
    """Log in with a username and password.

    Requests a token using the ``password`` grant (sending ``scope`` too
    when one is given), saves it to the auth file for the current mode and
    returns a ``Cloud`` built from it. Returns a coroutine.
    """
    payload = {'grant_type': 'password'}
    payload['username'] = username
    payload['password'] = password
    if scope is not None:
        payload['scope'] = scope

    token = yield from _async_get_token(hass, payload)
    yield from hass.async_add_job(_write_auth, hass, token)
    return Cloud(hass, token)
@asyncio.coroutine
def _async_get_token(hass, data):
    """Get a new token and return it as a dictionary.

    ``data`` is posted as the form body of the token request (for example
    a ``password`` or ``refresh_token`` grant), authenticated with the
    client credentials of the current mode. On success the decoded JSON
    response is returned with an added ``expires_at`` key: the ISO 8601
    time ``expires_in`` seconds from now.

    Raises exceptions when errors occur:
    - Unauthenticated: the server answered with HTTP 401.
    - UnknownError: any other non-200 status, a connection error, or a
      200 response whose body is not a usable token payload.
    """
    session = async_get_clientsession(hass)
    auth = aiohttp.BasicAuth(*_client_credentials(hass))
    url = _url(hass, URL_CREATE_TOKEN)

    try:
        req = yield from session.post(url, data=data, auth=auth)

        if req.status != 200:
            _LOGGER.error('Cloud login failed: %d', req.status)
            # The body is not read on this path, so hand the connection
            # back to the pool explicitly instead of leaking it.
            yield from req.release()
            if req.status == 401:
                raise Unauthenticated(status=req.status)
            raise UnknownError(status=req.status)

        response = yield from req.json()
        expires_at = utcnow() + timedelta(seconds=response['expires_in'])
        response['expires_at'] = expires_at.isoformat()
    except aiohttp.ClientError as err:
        raise UnknownError() from err
    except (ValueError, KeyError, TypeError) as err:
        # Body was not JSON, not an object, or had no usable expires_in.
        # Surface it as a CloudError so callers such as the refresh logic
        # handle it like any other failed token request.
        _LOGGER.error('Cloud login failed: invalid token response (%s)', err)
        raise UnknownError() from err

    return response
class Cloud:
    """Store Hass Cloud info: the token data and the fetched account info."""

    def __init__(self, hass, auth):
        """Initialize Hass cloud info object.

        ``auth`` is a token dictionary as returned by the token endpoint;
        the properties below read its ``access_token`` and
        ``refresh_token`` keys.
        """
        self.hass = hass
        # Set to None by async_revoke_access_token; the token properties
        # cannot be used after that.
        self.auth = auth
        # Filled in by async_refresh_account_info.
        self.account = None

    @property
    def access_token(self):
        """Return access token."""
        return self.auth['access_token']

    @property
    def refresh_token(self):
        """Get refresh token."""
        return self.auth['refresh_token']

    @asyncio.coroutine
    def async_refresh_account_info(self):
        """Refresh the account info.

        Stores the decoded account JSON in ``self.account`` and returns
        True, or returns False when the server answers with a non-200
        status.
        """
        req = yield from self.async_request('get', URL_ACCOUNT)

        if req.status != 200:
            # The body is not read, so release the connection explicitly.
            yield from req.release()
            return False

        self.account = yield from req.json()
        return True

    @asyncio.coroutine
    def async_refresh_access_token(self):
        """Get a token using a refresh token.

        On success the new token replaces ``self.auth``, is written to the
        auth file and True is returned. Any CloudError from the token
        request results in False.
        """
        try:
            self.auth = yield from _async_get_token(self.hass, {
                'grant_type': 'refresh_token',
                'refresh_token': self.refresh_token,
            })
            yield from self.hass.async_add_job(
                _write_auth, self.hass, self.auth)
            return True
        except CloudError:
            return False

    @asyncio.coroutine
    def async_revoke_access_token(self):
        """Revoke active access token.

        On success ``self.auth`` is cleared and the entry for the current
        mode is removed from the auth file.

        Raises UnknownError when the server answers with a non-200 status
        or cannot be reached.
        """
        session = async_get_clientsession(self.hass)
        client_id, client_secret = _client_credentials(self.hass)
        data = {
            'token': self.access_token,
            'client_id': client_id,
            'client_secret': client_secret
        }

        try:
            req = yield from session.post(
                _url(self.hass, URL_REVOKE_TOKEN),
                data=data,
            )
            # The body is never read on either path; return the connection
            # to the pool instead of leaking it.
            yield from req.release()
        except aiohttp.ClientError as err:
            raise UnknownError() from err

        if req.status != 200:
            _LOGGER.error('Cloud logout failed: %d', req.status)
            raise UnknownError(status=req.status)

        self.auth = None
        yield from self.hass.async_add_job(
            _write_auth, self.hass, None)

    @asyncio.coroutine
    def async_request(self, method, path, **kwargs):
        """Make a request to Home Assistant cloud.

        Will refresh the token if necessary: the request carries the
        current access token as a bearer token. On a 403 answer the access
        token is refreshed once and, if that worked, the request is sent
        again with the new token (after refreshing the account info,
        unless this request is the account-info request itself).

        Returns the response object without reading its body.
        """
        session = async_get_clientsession(self.hass)
        url = _url(self.hass, path)

        # Work on a copy so the caller's headers mapping never ends up
        # holding our bearer token; this also tolerates headers=None.
        headers = dict(kwargs.get('headers') or {})
        kwargs['headers'] = headers
        headers['authorization'] = 'Bearer {}'.format(self.access_token)

        request = yield from session.request(method, url, **kwargs)

        if request.status != 403:
            return request

        # Maybe token expired. Try refreshing it.
        reauth = yield from self.async_refresh_access_token()

        if not reauth:
            return request

        # Release old connection back to the pool.
        yield from request.release()

        headers['authorization'] = 'Bearer {}'.format(self.access_token)

        # If we are not already fetching the account info,
        # refresh the account info.
        if path != URL_ACCOUNT:
            yield from self.async_refresh_account_info()

        request = yield from session.request(method, url, **kwargs)

        return request
def _read_auth(hass):
    """Return the stored auth data for the current mode.

    Gives None when the auth file does not exist or holds no entry for
    the current mode.
    """
    auth_path = hass.config.path(AUTH_FILE)

    if os.path.isfile(auth_path):
        with open(auth_path) as auth_file:
            stored = json.load(auth_file)
        return stored.get(get_mode(hass))

    return None
def _write_auth(hass, data):
    """Write auth info for specified mode.

    Pass in None for data to remove authentication for that mode.
    Entries for other modes already in the file are kept.

    The new content goes to a temporary file in the same directory, which
    then replaces the auth file. A failed or interrupted write therefore
    leaves the previous file intact. The temporary file is created
    readable and writable only by its owner, which matters because the
    file holds access and refresh tokens.
    """
    path = hass.config.path(AUTH_FILE)
    mode = get_mode(hass)

    if os.path.isfile(path):
        with open(path) as file:
            content = json.load(file)
    else:
        content = {}

    if data is None:
        content.pop(mode, None)
    else:
        content[mode] = data

    # mkstemp creates the file with owner-only permissions. Using the same
    # directory keeps os.replace a rename within one filesystem (atomic
    # on POSIX).
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(os.path.abspath(path)))
    try:
        with open(fd, 'wt') as file:
            file.write(json.dumps(content, indent=4, sort_keys=True))
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def _client_credentials(hass):
    """Return the ``(client_id, client_secret)`` pair for the active mode.

    Async friendly.

    Raises ValueError when the active mode has no entry in SERVERS.
    """
    mode = get_mode(hass)

    if mode in SERVERS:
        server = SERVERS[mode]
        return server['client_id'], server['client_secret']

    raise ValueError('Mode {} is not supported.'.format(mode))
def _url(hass, path):
    """Build the absolute cloud URL for ``path`` in the active mode.

    Async friendly. The path is resolved against the mode's ``host`` with
    ``urljoin``, so the result depends on whether that host ends in a
    slash.

    Raises ValueError when the active mode has no entry in SERVERS.
    """
    mode = get_mode(hass)

    if mode in SERVERS:
        return urljoin(SERVERS[mode]['host'], path)

    raise ValueError('Mode {} is not supported.'.format(mode))