Get user after login flow finished (#16047)

* Get user after login flow finished

* Add optional parameter 'type' to /auth/login_flow

* Update __init__.py
This commit is contained in:
Jason Hu 2018-08-21 01:18:04 -07:00 committed by Paulus Schoutsen
parent b1ba11510b
commit f84a31871e
6 changed files with 83 additions and 52 deletions

View file

@ -2,7 +2,7 @@
import asyncio
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, cast, Union
import jwt
@ -257,15 +257,20 @@ class AuthManager:
async def _async_finish_login_flow(
self, context: Optional[Dict], result: Dict[str, Any]) \
-> Optional[models.Credentials]:
"""Result of a credential login flow."""
-> Optional[Union[models.User, models.Credentials]]:
"""Return a user as result of login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
auth_provider = self._providers[result['handler']]
return await auth_provider.async_get_or_create_credentials(
cred = await auth_provider.async_get_or_create_credentials(
result['data'])
if context is not None and context.get('credential_only'):
return cred
return await self.async_get_or_create_user(cred)
@callback
def _async_get_auth_provider(
self, credentials: models.Credentials) -> Optional[AuthProvider]:

View file

@ -51,6 +51,7 @@ from datetime import timedelta
import voluptuous as vol
from homeassistant.auth.models import User, Credentials
from homeassistant.components import websocket_api
from homeassistant.components.http.ban import log_invalid_auth
from homeassistant.components.http.data_validator import RequestDataValidator
@ -68,22 +69,25 @@ SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_CURRENT_USER,
})
RESULT_TYPE_CREDENTIALS = 'credentials'
RESULT_TYPE_USER = 'user'
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass, config):
"""Component to allow users to login."""
store_credentials, retrieve_credentials = _create_cred_store()
store_result, retrieve_result = _create_auth_code_store()
hass.http.register_view(GrantTokenView(retrieve_credentials))
hass.http.register_view(LinkUserView(retrieve_credentials))
hass.http.register_view(GrantTokenView(retrieve_result))
hass.http.register_view(LinkUserView(retrieve_result))
hass.components.websocket_api.async_register_command(
WS_TYPE_CURRENT_USER, websocket_current_user,
SCHEMA_WS_CURRENT_USER
)
await login_flow.async_setup(hass, store_credentials)
await login_flow.async_setup(hass, store_result)
return True
@ -96,9 +100,9 @@ class GrantTokenView(HomeAssistantView):
requires_auth = False
cors_allowed = True
def __init__(self, retrieve_credentials):
def __init__(self, retrieve_user):
"""Initialize the grant token view."""
self._retrieve_credentials = retrieve_credentials
self._retrieve_user = retrieve_user
@log_invalid_auth
async def post(self, request):
@ -134,15 +138,16 @@ class GrantTokenView(HomeAssistantView):
'error': 'invalid_request',
}, status_code=400)
credentials = self._retrieve_credentials(client_id, code)
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
if credentials is None:
if user is None or not isinstance(user, User):
return self.json({
'error': 'invalid_request',
'error_description': 'Invalid code',
}, status_code=400)
user = await hass.auth.async_get_or_create_user(credentials)
# refresh user
user = await hass.auth.async_get_user(user.id)
if not user.is_active:
return self.json({
@ -220,7 +225,7 @@ class LinkUserView(HomeAssistantView):
user = request['hass_user']
credentials = self._retrieve_credentials(
data['client_id'], data['code'])
data['client_id'], RESULT_TYPE_CREDENTIALS, data['code'])
if credentials is None:
return self.json_message('Invalid code', status_code=400)
@ -230,37 +235,45 @@ class LinkUserView(HomeAssistantView):
@callback
def _create_cred_store():
"""Create a credential store."""
temp_credentials = {}
def _create_auth_code_store():
"""Create an in memory store."""
temp_results = {}
@callback
def store_credentials(client_id, credentials):
"""Store credentials and return a code to retrieve it."""
def store_result(client_id, result):
"""Store flow result and return a code to retrieve it."""
if isinstance(result, User):
result_type = RESULT_TYPE_USER
elif isinstance(result, Credentials):
result_type = RESULT_TYPE_CREDENTIALS
else:
raise ValueError('result has to be either User or Credentials')
code = uuid.uuid4().hex
temp_credentials[(client_id, code)] = (dt_util.utcnow(), credentials)
temp_results[(client_id, result_type, code)] = \
(dt_util.utcnow(), result_type, result)
return code
@callback
def retrieve_credentials(client_id, code):
"""Retrieve credentials."""
key = (client_id, code)
def retrieve_result(client_id, result_type, code):
"""Retrieve flow result."""
key = (client_id, result_type, code)
if key not in temp_credentials:
if key not in temp_results:
return None
created, credentials = temp_credentials.pop(key)
created, _, result = temp_results.pop(key)
# OAuth 4.2.1
# The authorization code MUST expire shortly after it is issued to
# mitigate the risk of leaks. A maximum authorization code lifetime of
# 10 minutes is RECOMMENDED.
if dt_util.utcnow() - created < timedelta(minutes=10):
return credentials
return result
return None
return store_credentials, retrieve_credentials
return store_result, retrieve_result
@callback

View file

@ -22,10 +22,14 @@ Pass in parameter 'client_id' and 'redirect_url' validate by indieauth.
Pass in parameter 'handler' to specify the auth provider to use. Auth providers
are identified by type and id.
And optional parameter 'type' has to set as 'link_user' if login flow used for
link credential to exist user. Default 'type' is 'authorize'.
{
"client_id": "https://hassbian.local:8123/",
"handler": ["local_provider", null],
"redirect_url": "https://hassbian.local:8123/"
"redirect_url": "https://hassbian.local:8123/",
"type': "authorize"
}
Return value will be a step in a data entry flow. See the docs for data entry
@ -49,6 +53,9 @@ flow for details.
Progress the flow. Most flows will be 1 page, but could optionally add extra
login challenges, like TFA. Once the flow has finished, the returned step will
have type "create_entry" and "result" key will contain an authorization code.
The authorization code associated with an authorized user by default, it will
associate with an credential if "type" set to "link_user" in
"/auth/login_flow"
{
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
@ -71,12 +78,12 @@ from homeassistant.components.http.view import HomeAssistantView
from . import indieauth
async def async_setup(hass, store_credentials):
async def async_setup(hass, store_result):
"""Component to allow users to login."""
hass.http.register_view(AuthProvidersView)
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
hass.http.register_view(
LoginFlowResourceView(hass.auth.login_flow, store_credentials))
LoginFlowResourceView(hass.auth.login_flow, store_result))
class AuthProvidersView(HomeAssistantView):
@ -138,6 +145,7 @@ class LoginFlowIndexView(HomeAssistantView):
vol.Required('client_id'): str,
vol.Required('handler'): vol.Any(str, list),
vol.Required('redirect_uri'): str,
vol.Optional('type', default='authorize'): str,
}))
@log_invalid_auth
async def post(self, request, data):
@ -153,7 +161,10 @@ class LoginFlowIndexView(HomeAssistantView):
try:
result = await self._flow_mgr.async_init(
handler, context={'ip_address': request[KEY_REAL_IP]})
handler, context={
'ip_address': request[KEY_REAL_IP],
'credential_only': data.get('type') == 'link_user',
})
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:
@ -169,10 +180,10 @@ class LoginFlowResourceView(HomeAssistantView):
name = 'api:auth:login_flow:resource'
requires_auth = False
def __init__(self, flow_mgr, store_credentials):
def __init__(self, flow_mgr, store_result):
"""Initialize the login flow resource view."""
self._flow_mgr = flow_mgr
self._store_credentials = store_credentials
self._store_result = store_result
async def get(self, request):
"""Do not allow getting status of a flow in progress."""
@ -212,7 +223,7 @@ class LoginFlowResourceView(HomeAssistantView):
return self.json(_prepare_result_json(result))
result.pop('data')
result['result'] = self._store_credentials(client_id, result['result'])
result['result'] = self._store_result(client_id, result['result'])
return self.json(result)

View file

@ -77,8 +77,7 @@ async def test_create_new_user(hass, hass_storage):
'password': 'test-pass',
})
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
credentials = step['result']
user = await manager.async_get_or_create_user(credentials)
user = step['result']
assert user is not None
assert user.is_owner is False
assert user.name == 'Test Name'
@ -134,9 +133,8 @@ async def test_login_as_existing_user(mock_hass):
'password': 'test-pass',
})
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
credentials = step['result']
user = await manager.async_get_or_create_user(credentials)
user = step['result']
assert user is not None
assert user.id == 'mock-user'
assert user.is_owner is False
@ -166,16 +164,18 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
'username': 'test-user',
'password': 'test-pass',
})
user = await manager.async_get_or_create_user(step['result'])
user = step['result']
assert user is not None
step = await manager.login_flow.async_init(('insecure_example',
'another-provider'))
step = await manager.login_flow.async_init(
('insecure_example', 'another-provider'),
context={'credential_only': True})
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'another-user',
'password': 'another-password',
})
await manager.async_link_user(user, step['result'])
new_credential = step['result']
await manager.async_link_user(user, new_credential)
assert len(user.credentials) == 2
@ -197,7 +197,7 @@ async def test_saving_loading(hass, hass_storage):
'username': 'test-user',
'password': 'test-pass',
})
user = await manager.async_get_or_create_user(step['result'])
user = step['result']
await manager.async_activate_user(user)
await manager.async_create_refresh_token(user, CLIENT_ID)

View file

@ -3,13 +3,14 @@ from datetime import timedelta
from unittest.mock import patch
from homeassistant.auth.models import Credentials
from homeassistant.components.auth import RESULT_TYPE_USER
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from homeassistant.components import auth
from . import async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
@ -74,26 +75,26 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
assert resp.status == 200
def test_credential_store_expiration():
"""Test that the credential store will not return expired tokens."""
store, retrieve = auth._create_cred_store()
def test_auth_code_store_expiration():
"""Test that the auth code store will not return expired tokens."""
store, retrieve = auth._create_auth_code_store()
client_id = 'bla'
credentials = 'creds'
user = MockUser(id='mock_user')
now = utcnow()
with patch('homeassistant.util.dt.utcnow', return_value=now):
code = store(client_id, credentials)
code = store(client_id, user)
with patch('homeassistant.util.dt.utcnow',
return_value=now + timedelta(minutes=10)):
assert retrieve(client_id, code) is None
assert retrieve(client_id, RESULT_TYPE_USER, code) is None
with patch('homeassistant.util.dt.utcnow', return_value=now):
code = store(client_id, credentials)
code = store(client_id, user)
with patch('homeassistant.util.dt.utcnow',
return_value=now + timedelta(minutes=9, seconds=59)):
assert retrieve(client_id, code) == credentials
assert retrieve(client_id, RESULT_TYPE_USER, code) == user
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):

View file

@ -34,6 +34,7 @@ async def async_get_code(hass, aiohttp_client):
'client_id': CLIENT_ID,
'handler': ['insecure_example', '2nd auth'],
'redirect_uri': CLIENT_REDIRECT_URI,
'type': 'link_user',
})
assert resp.status == 200
step = await resp.json()