Get user after login flow finished (#16047)
* Get user after login flow finished * Add optional parameter 'type' to /auth/login_flow * Update __init__.py
This commit is contained in:
parent
b1ba11510b
commit
f84a31871e
6 changed files with 83 additions and 52 deletions
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -257,15 +257,20 @@ class AuthManager:
|
|||
|
||||
async def _async_finish_login_flow(
|
||||
self, context: Optional[Dict], result: Dict[str, Any]) \
|
||||
-> Optional[models.Credentials]:
|
||||
"""Result of a credential login flow."""
|
||||
-> Optional[Union[models.User, models.Credentials]]:
|
||||
"""Return a user as result of login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
||||
auth_provider = self._providers[result['handler']]
|
||||
return await auth_provider.async_get_or_create_credentials(
|
||||
cred = await auth_provider.async_get_or_create_credentials(
|
||||
result['data'])
|
||||
|
||||
if context is not None and context.get('credential_only'):
|
||||
return cred
|
||||
|
||||
return await self.async_get_or_create_user(cred)
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(
|
||||
self, credentials: models.Credentials) -> Optional[AuthProvider]:
|
||||
|
|
|
@ -51,6 +51,7 @@ from datetime import timedelta
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.models import User, Credentials
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http.ban import log_invalid_auth
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
|
@ -68,22 +69,25 @@ SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
|||
vol.Required('type'): WS_TYPE_CURRENT_USER,
|
||||
})
|
||||
|
||||
RESULT_TYPE_CREDENTIALS = 'credentials'
|
||||
RESULT_TYPE_USER = 'user'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Component to allow users to login."""
|
||||
store_credentials, retrieve_credentials = _create_cred_store()
|
||||
store_result, retrieve_result = _create_auth_code_store()
|
||||
|
||||
hass.http.register_view(GrantTokenView(retrieve_credentials))
|
||||
hass.http.register_view(LinkUserView(retrieve_credentials))
|
||||
hass.http.register_view(GrantTokenView(retrieve_result))
|
||||
hass.http.register_view(LinkUserView(retrieve_result))
|
||||
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_CURRENT_USER, websocket_current_user,
|
||||
SCHEMA_WS_CURRENT_USER
|
||||
)
|
||||
|
||||
await login_flow.async_setup(hass, store_credentials)
|
||||
await login_flow.async_setup(hass, store_result)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -96,9 +100,9 @@ class GrantTokenView(HomeAssistantView):
|
|||
requires_auth = False
|
||||
cors_allowed = True
|
||||
|
||||
def __init__(self, retrieve_credentials):
|
||||
def __init__(self, retrieve_user):
|
||||
"""Initialize the grant token view."""
|
||||
self._retrieve_credentials = retrieve_credentials
|
||||
self._retrieve_user = retrieve_user
|
||||
|
||||
@log_invalid_auth
|
||||
async def post(self, request):
|
||||
|
@ -134,15 +138,16 @@ class GrantTokenView(HomeAssistantView):
|
|||
'error': 'invalid_request',
|
||||
}, status_code=400)
|
||||
|
||||
credentials = self._retrieve_credentials(client_id, code)
|
||||
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
|
||||
|
||||
if credentials is None:
|
||||
if user is None or not isinstance(user, User):
|
||||
return self.json({
|
||||
'error': 'invalid_request',
|
||||
'error_description': 'Invalid code',
|
||||
}, status_code=400)
|
||||
|
||||
user = await hass.auth.async_get_or_create_user(credentials)
|
||||
# refresh user
|
||||
user = await hass.auth.async_get_user(user.id)
|
||||
|
||||
if not user.is_active:
|
||||
return self.json({
|
||||
|
@ -220,7 +225,7 @@ class LinkUserView(HomeAssistantView):
|
|||
user = request['hass_user']
|
||||
|
||||
credentials = self._retrieve_credentials(
|
||||
data['client_id'], data['code'])
|
||||
data['client_id'], RESULT_TYPE_CREDENTIALS, data['code'])
|
||||
|
||||
if credentials is None:
|
||||
return self.json_message('Invalid code', status_code=400)
|
||||
|
@ -230,37 +235,45 @@ class LinkUserView(HomeAssistantView):
|
|||
|
||||
|
||||
@callback
|
||||
def _create_cred_store():
|
||||
"""Create a credential store."""
|
||||
temp_credentials = {}
|
||||
def _create_auth_code_store():
|
||||
"""Create an in memory store."""
|
||||
temp_results = {}
|
||||
|
||||
@callback
|
||||
def store_credentials(client_id, credentials):
|
||||
"""Store credentials and return a code to retrieve it."""
|
||||
def store_result(client_id, result):
|
||||
"""Store flow result and return a code to retrieve it."""
|
||||
if isinstance(result, User):
|
||||
result_type = RESULT_TYPE_USER
|
||||
elif isinstance(result, Credentials):
|
||||
result_type = RESULT_TYPE_CREDENTIALS
|
||||
else:
|
||||
raise ValueError('result has to be either User or Credentials')
|
||||
|
||||
code = uuid.uuid4().hex
|
||||
temp_credentials[(client_id, code)] = (dt_util.utcnow(), credentials)
|
||||
temp_results[(client_id, result_type, code)] = \
|
||||
(dt_util.utcnow(), result_type, result)
|
||||
return code
|
||||
|
||||
@callback
|
||||
def retrieve_credentials(client_id, code):
|
||||
"""Retrieve credentials."""
|
||||
key = (client_id, code)
|
||||
def retrieve_result(client_id, result_type, code):
|
||||
"""Retrieve flow result."""
|
||||
key = (client_id, result_type, code)
|
||||
|
||||
if key not in temp_credentials:
|
||||
if key not in temp_results:
|
||||
return None
|
||||
|
||||
created, credentials = temp_credentials.pop(key)
|
||||
created, _, result = temp_results.pop(key)
|
||||
|
||||
# OAuth 4.2.1
|
||||
# The authorization code MUST expire shortly after it is issued to
|
||||
# mitigate the risk of leaks. A maximum authorization code lifetime of
|
||||
# 10 minutes is RECOMMENDED.
|
||||
if dt_util.utcnow() - created < timedelta(minutes=10):
|
||||
return credentials
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
return store_credentials, retrieve_credentials
|
||||
return store_result, retrieve_result
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -22,10 +22,14 @@ Pass in parameter 'client_id' and 'redirect_url' validate by indieauth.
|
|||
Pass in parameter 'handler' to specify the auth provider to use. Auth providers
|
||||
are identified by type and id.
|
||||
|
||||
And optional parameter 'type' has to set as 'link_user' if login flow used for
|
||||
link credential to exist user. Default 'type' is 'authorize'.
|
||||
|
||||
{
|
||||
"client_id": "https://hassbian.local:8123/",
|
||||
"handler": ["local_provider", null],
|
||||
"redirect_url": "https://hassbian.local:8123/"
|
||||
"redirect_url": "https://hassbian.local:8123/",
|
||||
"type': "authorize"
|
||||
}
|
||||
|
||||
Return value will be a step in a data entry flow. See the docs for data entry
|
||||
|
@ -49,6 +53,9 @@ flow for details.
|
|||
Progress the flow. Most flows will be 1 page, but could optionally add extra
|
||||
login challenges, like TFA. Once the flow has finished, the returned step will
|
||||
have type "create_entry" and "result" key will contain an authorization code.
|
||||
The authorization code associated with an authorized user by default, it will
|
||||
associate with an credential if "type" set to "link_user" in
|
||||
"/auth/login_flow"
|
||||
|
||||
{
|
||||
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
|
||||
|
@ -71,12 +78,12 @@ from homeassistant.components.http.view import HomeAssistantView
|
|||
from . import indieauth
|
||||
|
||||
|
||||
async def async_setup(hass, store_credentials):
|
||||
async def async_setup(hass, store_result):
|
||||
"""Component to allow users to login."""
|
||||
hass.http.register_view(AuthProvidersView)
|
||||
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
|
||||
hass.http.register_view(
|
||||
LoginFlowResourceView(hass.auth.login_flow, store_credentials))
|
||||
LoginFlowResourceView(hass.auth.login_flow, store_result))
|
||||
|
||||
|
||||
class AuthProvidersView(HomeAssistantView):
|
||||
|
@ -138,6 +145,7 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||
vol.Required('client_id'): str,
|
||||
vol.Required('handler'): vol.Any(str, list),
|
||||
vol.Required('redirect_uri'): str,
|
||||
vol.Optional('type', default='authorize'): str,
|
||||
}))
|
||||
@log_invalid_auth
|
||||
async def post(self, request, data):
|
||||
|
@ -153,7 +161,10 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler, context={'ip_address': request[KEY_REAL_IP]})
|
||||
handler, context={
|
||||
'ip_address': request[KEY_REAL_IP],
|
||||
'credential_only': data.get('type') == 'link_user',
|
||||
})
|
||||
except data_entry_flow.UnknownHandler:
|
||||
return self.json_message('Invalid handler specified', 404)
|
||||
except data_entry_flow.UnknownStep:
|
||||
|
@ -169,10 +180,10 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||
name = 'api:auth:login_flow:resource'
|
||||
requires_auth = False
|
||||
|
||||
def __init__(self, flow_mgr, store_credentials):
|
||||
def __init__(self, flow_mgr, store_result):
|
||||
"""Initialize the login flow resource view."""
|
||||
self._flow_mgr = flow_mgr
|
||||
self._store_credentials = store_credentials
|
||||
self._store_result = store_result
|
||||
|
||||
async def get(self, request):
|
||||
"""Do not allow getting status of a flow in progress."""
|
||||
|
@ -212,7 +223,7 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||
return self.json(_prepare_result_json(result))
|
||||
|
||||
result.pop('data')
|
||||
result['result'] = self._store_credentials(client_id, result['result'])
|
||||
result['result'] = self._store_result(client_id, result['result'])
|
||||
|
||||
return self.json(result)
|
||||
|
||||
|
|
|
@ -77,8 +77,7 @@ async def test_create_new_user(hass, hass_storage):
|
|||
'password': 'test-pass',
|
||||
})
|
||||
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
credentials = step['result']
|
||||
user = await manager.async_get_or_create_user(credentials)
|
||||
user = step['result']
|
||||
assert user is not None
|
||||
assert user.is_owner is False
|
||||
assert user.name == 'Test Name'
|
||||
|
@ -134,9 +133,8 @@ async def test_login_as_existing_user(mock_hass):
|
|||
'password': 'test-pass',
|
||||
})
|
||||
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
credentials = step['result']
|
||||
|
||||
user = await manager.async_get_or_create_user(credentials)
|
||||
user = step['result']
|
||||
assert user is not None
|
||||
assert user.id == 'mock-user'
|
||||
assert user.is_owner is False
|
||||
|
@ -166,16 +164,18 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
|
|||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
user = await manager.async_get_or_create_user(step['result'])
|
||||
user = step['result']
|
||||
assert user is not None
|
||||
|
||||
step = await manager.login_flow.async_init(('insecure_example',
|
||||
'another-provider'))
|
||||
step = await manager.login_flow.async_init(
|
||||
('insecure_example', 'another-provider'),
|
||||
context={'credential_only': True})
|
||||
step = await manager.login_flow.async_configure(step['flow_id'], {
|
||||
'username': 'another-user',
|
||||
'password': 'another-password',
|
||||
})
|
||||
await manager.async_link_user(user, step['result'])
|
||||
new_credential = step['result']
|
||||
await manager.async_link_user(user, new_credential)
|
||||
assert len(user.credentials) == 2
|
||||
|
||||
|
||||
|
@ -197,7 +197,7 @@ async def test_saving_loading(hass, hass_storage):
|
|||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
user = await manager.async_get_or_create_user(step['result'])
|
||||
user = step['result']
|
||||
await manager.async_activate_user(user)
|
||||
await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
|
|
|
@ -3,13 +3,14 @@ from datetime import timedelta
|
|||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.components.auth import RESULT_TYPE_USER
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
from homeassistant.components import auth
|
||||
|
||||
from . import async_setup_auth
|
||||
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
||||
|
||||
|
||||
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||
|
@ -74,26 +75,26 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||
assert resp.status == 200
|
||||
|
||||
|
||||
def test_credential_store_expiration():
|
||||
"""Test that the credential store will not return expired tokens."""
|
||||
store, retrieve = auth._create_cred_store()
|
||||
def test_auth_code_store_expiration():
|
||||
"""Test that the auth code store will not return expired tokens."""
|
||||
store, retrieve = auth._create_auth_code_store()
|
||||
client_id = 'bla'
|
||||
credentials = 'creds'
|
||||
user = MockUser(id='mock_user')
|
||||
now = utcnow()
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
||||
code = store(client_id, credentials)
|
||||
code = store(client_id, user)
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=now + timedelta(minutes=10)):
|
||||
assert retrieve(client_id, code) is None
|
||||
assert retrieve(client_id, RESULT_TYPE_USER, code) is None
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
||||
code = store(client_id, credentials)
|
||||
code = store(client_id, user)
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=now + timedelta(minutes=9, seconds=59)):
|
||||
assert retrieve(client_id, code) == credentials
|
||||
assert retrieve(client_id, RESULT_TYPE_USER, code) == user
|
||||
|
||||
|
||||
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
||||
|
|
|
@ -34,6 +34,7 @@ async def async_get_code(hass, aiohttp_client):
|
|||
'client_id': CLIENT_ID,
|
||||
'handler': ['insecure_example', '2nd auth'],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
'type': 'link_user',
|
||||
})
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
|
Loading…
Add table
Reference in a new issue