* Add initial cloud auth * Move hass.data to a dict * Move mode into helper * Fix bugs after refactor * Add tests * Clean up scripts file after test config * Lint * Update __init__.py
297 lines
7.5 KiB
Python
297 lines
7.5 KiB
Python
"""Package to offer tools to communicate with the cloud."""
|
|
import asyncio
|
|
from datetime import timedelta
|
|
import json
|
|
import logging
|
|
import os
|
|
from urllib.parse import urljoin
|
|
|
|
import aiohttp
|
|
import async_timeout
|
|
|
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
|
from homeassistant.util.dt import utcnow
|
|
|
|
from .const import AUTH_FILE, REQUEST_TIMEOUT, SERVERS
|
|
from .util import get_mode
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Endpoint paths; each is joined onto the current mode's host by _url().
# Token endpoint, used for both the 'password' and 'refresh_token' grants.
URL_CREATE_TOKEN = 'o/token/'
# Endpoint used by Cloud.async_revoke_access_token on logout.
URL_REVOKE_TOKEN = 'o/revoke_token/'
# Account info fetched by Cloud.async_refresh_account_info.
URL_ACCOUNT = 'account.json'
|
|
|
|
|
|
class CloudError(Exception):
    """Common base for every error raised by the cloud helpers.

    ``status`` is an optional HTTP status code (the helpers in this module
    pass the response status when one is available); it defaults to None.
    """

    def __init__(self, reason=None, status=None):
        """Remember the status and hand the reason to Exception."""
        self.status = status
        super().__init__(reason)


class Unauthenticated(CloudError):
    """Signals that the server rejected the supplied credentials."""


class UnknownError(CloudError):
    """Signals a cloud failure that has no more specific error type."""
|
|
|
|
|
|
@asyncio.coroutine
def async_load_auth(hass):
    """Load authentication from disk and verify it.

    Reads the stored auth for the current mode (in the executor) and
    validates it by fetching the account info, bounded by REQUEST_TIMEOUT.

    Returns:
        A Cloud instance whose account info was fetched successfully, or
        None when no auth is stored, the server rejected the credentials,
        or the server could not be reached.
    """
    auth = yield from hass.async_add_job(_read_auth, hass)

    if not auth:
        return None

    cloud = Cloud(hass, auth)

    try:
        with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop):
            auth_check = yield from cloud.async_refresh_account_info()
    except (asyncio.TimeoutError, aiohttp.ClientError):
        # A connection/DNS failure means the server is just as unreachable
        # as a timeout; previously such errors escaped and aborted loading.
        _LOGGER.error('Unable to reach server to validate credentials.')
        return None

    if not auth_check:
        _LOGGER.error('Unable to validate credentials.')
        return None

    return cloud
|
|
|
|
|
|
@asyncio.coroutine
def async_login(hass, username, password, scope=None):
    """Log in with a username/password and return a Cloud instance.

    The token obtained via the 'password' grant is written to disk for the
    current mode before the Cloud object is built. Errors raised while
    requesting the token (Unauthenticated / UnknownError) propagate.

    Returns a coroutine.
    """
    credentials = dict(
        grant_type='password',
        username=username,
        password=password,
    )
    if scope is not None:
        credentials['scope'] = scope

    token_info = yield from _async_get_token(hass, credentials)
    yield from hass.async_add_job(_write_auth, hass, token_info)
    return Cloud(hass, token_info)
|
|
|
|
|
|
@asyncio.coroutine
def _async_get_token(hass, data):
    """Get a new token and return it as a dictionary.

    POSTs ``data`` (the grant parameters) to the token endpoint using the
    mode's client credentials as HTTP basic auth. An ``expires_at`` ISO
    timestamp is added to the returned dict, computed from ``expires_in``.

    Raises exceptions when errors occur:
    - Unauthenticated: the server answered 401.
    - UnknownError: any other non-200 status, or a client/connection error.
    """
    session = async_get_clientsession(hass)
    auth = aiohttp.BasicAuth(*_client_credentials(hass))

    try:
        req = yield from session.post(
            _url(hass, URL_CREATE_TOKEN),
            data=data,
            auth=auth
        )

        if req.status != 200:
            _LOGGER.error('Cloud login failed: %d', req.status)
            # The body is not read on failure, so hand the connection back
            # to the pool explicitly instead of leaving it checked out.
            yield from req.release()
            if req.status == 401:
                raise Unauthenticated(status=req.status)
            raise UnknownError(status=req.status)

        response = yield from req.json()
    except aiohttp.ClientError as err:
        raise UnknownError() from err

    response['expires_at'] = \
        (utcnow() + timedelta(seconds=response['expires_in'])).isoformat()

    return response
|
|
|
|
|
|
class Cloud:
    """Store Hass Cloud info.

    ``auth`` is the token dict returned by the token endpoint (at least
    ``access_token`` and ``refresh_token`` are read from it); ``account``
    holds the decoded account JSON once fetched, else None.
    """

    def __init__(self, hass, auth):
        """Initialize Hass cloud info object."""
        self.hass = hass
        self.auth = auth
        self.account = None

    @property
    def access_token(self):
        """Return access token."""
        return self.auth['access_token']

    @property
    def refresh_token(self):
        """Get refresh token."""
        return self.auth['refresh_token']

    @asyncio.coroutine
    def async_refresh_account_info(self):
        """Refresh the account info.

        Returns True and stores the decoded JSON in ``self.account`` on a
        200 response; returns False for any other status.
        """
        req = yield from self.async_request('get', URL_ACCOUNT)

        if req.status != 200:
            # The body is never read on this path, so release the connection
            # explicitly rather than leaving it checked out.
            yield from req.release()
            return False

        self.account = yield from req.json()
        return True

    @asyncio.coroutine
    def async_refresh_access_token(self):
        """Get a token using a refresh token.

        On success the new token dict replaces ``self.auth`` and is written
        to disk. Returns True on success, False on any CloudError.
        """
        try:
            self.auth = yield from _async_get_token(self.hass, {
                'grant_type': 'refresh_token',
                'refresh_token': self.refresh_token,
            })

            yield from self.hass.async_add_job(
                _write_auth, self.hass, self.auth)

            return True
        except CloudError:
            return False

    @asyncio.coroutine
    def async_revoke_access_token(self):
        """Revoke active access token.

        On success clears ``self.auth`` and removes the stored auth for the
        current mode from disk.

        Raises UnknownError for a non-200 response or a client/connection
        error.
        """
        session = async_get_clientsession(self.hass)
        client_id, client_secret = _client_credentials(self.hass)
        data = {
            'token': self.access_token,
            'client_id': client_id,
            'client_secret': client_secret
        }

        try:
            req = yield from session.post(
                _url(self.hass, URL_REVOKE_TOKEN),
                data=data,
            )
            # Only the status is used; return the connection to the pool.
            yield from req.release()
        except aiohttp.ClientError as err:
            raise UnknownError() from err

        if req.status != 200:
            _LOGGER.error('Cloud logout failed: %d', req.status)
            raise UnknownError(status=req.status)

        self.auth = None
        yield from self.hass.async_add_job(
            _write_auth, self.hass, None)

    @asyncio.coroutine
    def async_request(self, method, path, **kwargs):
        """Make a request to Home Assistant cloud.

        Will refresh the token if necessary: on a 403 response the access
        token is refreshed and, if that succeeds, the request is retried
        once with the new token. The caller's ``headers`` mapping (if any)
        is copied, never modified; ``headers=None`` is treated as empty.
        """
        session = async_get_clientsession(self.hass)
        url = _url(self.hass, path)

        # Work on a copy so the bearer token never leaks into a mapping the
        # caller owns (and may reuse for unrelated requests).
        headers = kwargs.get('headers')
        headers = {} if headers is None else headers.copy()
        headers['authorization'] = 'Bearer {}'.format(self.access_token)
        kwargs['headers'] = headers

        request = yield from session.request(method, url, **kwargs)

        if request.status != 403:
            return request

        # Maybe token expired. Try refreshing it.
        reauth = yield from self.async_refresh_access_token()

        if not reauth:
            return request

        # Release old connection back to the pool.
        yield from request.release()

        headers['authorization'] = 'Bearer {}'.format(self.access_token)

        # If we are not already fetching the account info,
        # refresh the account info.
        if path != URL_ACCOUNT:
            yield from self.async_refresh_account_info()

        request = yield from session.request(method, url, **kwargs)

        return request
|
|
|
|
|
|
def _read_auth(hass):
    """Return the stored auth entry for the current mode.

    The auth file maps mode names to auth dicts. Returns None when the file
    does not exist or has no entry for the current mode.
    """
    auth_path = hass.config.path(AUTH_FILE)

    if os.path.isfile(auth_path):
        with open(auth_path) as handle:
            per_mode = json.load(handle)
        return per_mode.get(get_mode(hass))

    return None
|
|
|
|
|
|
def _write_auth(hass, data):
    """Write auth info for specified mode.

    Pass in None for data to remove authentication for that mode. Entries
    stored for other modes are read back and preserved.

    The content is written to a temporary file in the same directory and
    then renamed over the auth file with os.replace. An interrupted write
    therefore cannot leave a truncated file behind (the rename is atomic on
    POSIX). The temporary file is created by mkstemp, so the resulting auth
    file is readable and writable only by the current user.
    """
    import tempfile

    path = hass.config.path(AUTH_FILE)
    mode = get_mode(hass)

    if os.path.isfile(path):
        with open(path) as file:
            content = json.load(file)
    else:
        content = {}

    if data is None:
        content.pop(mode, None)
    else:
        content[mode] = data

    # Same directory as the target so os.replace never crosses filesystems.
    fd, tmp_path = tempfile.mkstemp(
        prefix='.auth-', suffix='.tmp',
        dir=os.path.dirname(os.path.abspath(path)))
    try:
        with os.fdopen(fd, 'wt') as file:
            json.dump(content, file, indent=4, sort_keys=True)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp name is gone; otherwise drop
        # the partial file and leave the existing auth file untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
def _client_credentials(hass):
    """Return the (client_id, client_secret) pair for the current mode.

    Async friendly.

    Raises ValueError when the current mode has no SERVERS entry.
    """
    mode = get_mode(hass)

    try:
        server = SERVERS[mode]
    except KeyError:
        raise ValueError('Mode {} is not supported.'.format(mode)) from None

    return server['client_id'], server['client_secret']
|
|
|
|
|
|
def _url(hass, path):
    """Return the absolute cloud URL for ``path`` on the current mode's host.

    Async friendly.

    Raises ValueError when the current mode has no SERVERS entry.
    """
    mode = get_mode(hass)

    if mode in SERVERS:
        return urljoin(SERVERS[mode]['host'], path)

    raise ValueError('Mode {} is not supported.'.format(mode))
|