Allow system users to refresh tokens (#15574)
This commit is contained in:
parent
ea2ff6aae3
commit
4e7dbf9ce5
2 changed files with 93 additions and 17 deletions
|
@ -252,6 +252,20 @@ class GrantTokenView(HomeAssistantView):
|
|||
hass = request.app['hass']
|
||||
data = await request.post()
|
||||
|
||||
grant_type = data.get('grant_type')
|
||||
|
||||
if grant_type == 'authorization_code':
|
||||
return await self._async_handle_auth_code(hass, data)
|
||||
|
||||
if grant_type == 'refresh_token':
|
||||
return await self._async_handle_refresh_token(hass, data)
|
||||
|
||||
return self.json({
|
||||
'error': 'unsupported_grant_type',
|
||||
}, status_code=400)
|
||||
|
||||
async def _async_handle_auth_code(self, hass, data):
|
||||
"""Handle authorization code request."""
|
||||
client_id = data.get('client_id')
|
||||
if client_id is None or not indieauth.verify_client_id(client_id):
|
||||
return self.json({
|
||||
|
@ -259,21 +273,6 @@ class GrantTokenView(HomeAssistantView):
|
|||
'error_description': 'Invalid client id',
|
||||
}, status_code=400)
|
||||
|
||||
grant_type = data.get('grant_type')
|
||||
|
||||
if grant_type == 'authorization_code':
|
||||
return await self._async_handle_auth_code(hass, client_id, data)
|
||||
|
||||
if grant_type == 'refresh_token':
|
||||
return await self._async_handle_refresh_token(
|
||||
hass, client_id, data)
|
||||
|
||||
return self.json({
|
||||
'error': 'unsupported_grant_type',
|
||||
}, status_code=400)
|
||||
|
||||
async def _async_handle_auth_code(self, hass, client_id, data):
|
||||
"""Handle authorization code request."""
|
||||
code = data.get('code')
|
||||
|
||||
if code is None:
|
||||
|
@ -309,8 +308,15 @@ class GrantTokenView(HomeAssistantView):
|
|||
int(refresh_token.access_token_expiration.total_seconds()),
|
||||
})
|
||||
|
||||
async def _async_handle_refresh_token(self, hass, client_id, data):
|
||||
async def _async_handle_refresh_token(self, hass, data):
|
||||
"""Handle authorization code request."""
|
||||
client_id = data.get('client_id')
|
||||
if client_id is not None and not indieauth.verify_client_id(client_id):
|
||||
return self.json({
|
||||
'error': 'invalid_request',
|
||||
'error_description': 'Invalid client id',
|
||||
}, status_code=400)
|
||||
|
||||
token = data.get('refresh_token')
|
||||
|
||||
if token is None:
|
||||
|
@ -320,11 +326,16 @@ class GrantTokenView(HomeAssistantView):
|
|||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(token)
|
||||
|
||||
if refresh_token is None or refresh_token.client_id != client_id:
|
||||
if refresh_token is None:
|
||||
return self.json({
|
||||
'error': 'invalid_grant',
|
||||
}, status_code=400)
|
||||
|
||||
if refresh_token.client_id != client_id:
|
||||
return self.json({
|
||||
'error': 'invalid_request',
|
||||
}, status_code=400)
|
||||
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
return self.json({
|
||||
|
|
|
@ -130,3 +130,68 @@ async def test_cors_on_token(hass, aiohttp_client):
|
|||
'origin': 'http://example.com'
|
||||
})
|
||||
assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com'
|
||||
|
||||
|
||||
async def test_refresh_token_system_generated(hass, aiohttp_client):
|
||||
"""Test that we can get access tokens for system generated user."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
user = await hass.auth.async_create_system_user('Test System')
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user, None)
|
||||
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': 'https://this-is-not-allowed-for-system-users.com/',
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token.token,
|
||||
})
|
||||
|
||||
assert resp.status == 400
|
||||
result = await resp.json()
|
||||
assert result['error'] == 'invalid_request'
|
||||
|
||||
resp = await client.post('/auth/token', data={
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token.token,
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
|
||||
|
||||
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||
"""Test that we verify client ID."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
user = await hass.auth.async_create_user('Test User')
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
# No client ID
|
||||
resp = await client.post('/auth/token', data={
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token.token,
|
||||
})
|
||||
|
||||
assert resp.status == 400
|
||||
result = await resp.json()
|
||||
assert result['error'] == 'invalid_request'
|
||||
|
||||
# Different client ID
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': 'http://example-different.com',
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token.token,
|
||||
})
|
||||
|
||||
assert resp.status == 400
|
||||
result = await resp.json()
|
||||
assert result['error'] == 'invalid_request'
|
||||
|
||||
# Correct
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': CLIENT_ID,
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token.token,
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
|
|
Loading…
Add table
Reference in a new issue