Allow auth providers to influence is_active (#15557)

* Allow auth providers to influence is_active

* Fix auth script test
This commit is contained in:
Paulus Schoutsen 2018-07-19 22:10:36 +02:00 committed by GitHub
parent a42288d056
commit 2fcacbff23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 82 additions and 23 deletions

View file

@ -124,6 +124,7 @@ class AuthManager:
return await self._store.async_create_user( return await self._store.async_create_user(
credentials=credentials, credentials=credentials,
name=info.get('name'), name=info.get('name'),
is_active=info.get('is_active', False)
) )
async def async_link_user(self, user, credentials): async def async_link_user(self, user, credentials):

View file

@ -135,5 +135,9 @@ class AuthProvider:
"""Return extra user metadata for credentials. """Return extra user metadata for credentials.
Will be used to populate info when creating a new user. Will be used to populate info when creating a new user.
Values to populate:
- name: string
- is_active: boolean
""" """
return {} return {}

View file

@ -184,7 +184,8 @@ class HassAuthProvider(AuthProvider):
async def async_user_meta_for_credentials(self, credentials): async def async_user_meta_for_credentials(self, credentials):
"""Get extra info for this credential.""" """Get extra info for this credential."""
return { return {
'name': credentials.data['username'] 'name': credentials.data['username'],
'is_active': True,
} }
async def async_will_remove_credentials(self, credentials): async def async_will_remove_credentials(self, credentials):

View file

@ -75,14 +75,16 @@ class ExampleAuthProvider(AuthProvider):
Will be used to populate info when creating a new user. Will be used to populate info when creating a new user.
""" """
username = credentials.data['username'] username = credentials.data['username']
info = {
'is_active': True,
}
for user in self.config['users']: for user in self.config['users']:
if user['username'] == username: if user['username'] == username:
return { info['name'] = user.get('name')
'name': user.get('name') break
}
return {} return info
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):

View file

@ -70,7 +70,10 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
Will be used to populate info when creating a new user. Will be used to populate info when creating a new user.
""" """
return {'name': LEGACY_USER} return {
'name': LEGACY_USER,
'is_active': True,
}
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):

View file

@ -81,16 +81,9 @@ async def add_user(hass, provider, args):
print("Username already exists!") print("Username already exists!")
return return
credentials = await provider.async_get_or_create_credentials({
'username': args.username
})
user = await hass.auth.async_create_user(args.username)
await hass.auth.async_link_user(user, credentials)
# Save username/password # Save username/password
await provider.data.async_save() await provider.data.async_save()
print("User created") print("Auth created")
async def validate_login(hass, provider, args): async def validate_login(hass, provider, args):

View file

@ -4,6 +4,7 @@ from unittest.mock import Mock
import pytest import pytest
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import auth_manager_from_config
from homeassistant.auth.providers import ( from homeassistant.auth.providers import (
auth_provider_from_config, homeassistant as hass_auth) auth_provider_from_config, homeassistant as hass_auth)
@ -112,3 +113,20 @@ async def test_not_allow_set_id():
'id': 'invalid', 'id': 'invalid',
}) })
assert provider is None assert provider is None
async def test_new_users_populate_values(hass, data):
"""Test that we populate data for new users."""
data.add_auth('hello', 'test-pass')
await data.async_save()
manager = await auth_manager_from_config(hass, [{
'type': 'homeassistant'
}])
provider = manager.auth_providers[0]
credentials = await provider.async_get_or_create_credentials({
'username': 'hello'
})
user = await manager.async_get_or_create_user(credentials)
assert user.name == 'hello'
assert user.is_active

View file

@ -4,7 +4,7 @@ import uuid
import pytest import pytest
from homeassistant.auth import auth_store, models as auth_models from homeassistant.auth import auth_store, models as auth_models, AuthManager
from homeassistant.auth.providers import insecure_example from homeassistant.auth.providers import insecure_example
from tests.common import mock_coro from tests.common import mock_coro
@ -23,6 +23,7 @@ def provider(hass, store):
'type': 'insecure_example', 'type': 'insecure_example',
'users': [ 'users': [
{ {
'name': 'Test Name',
'username': 'user-test', 'username': 'user-test',
'password': 'password-test', 'password': 'password-test',
}, },
@ -34,7 +35,15 @@ def provider(hass, store):
}) })
async def test_create_new_credential(provider): @pytest.fixture
def manager(hass, store, provider):
"""Mock manager."""
return AuthManager(hass, store, {
(provider.type, provider.id): provider
})
async def test_create_new_credential(manager, provider):
"""Test that we create a new credential.""" """Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({ credentials = await provider.async_get_or_create_credentials({
'username': 'user-test', 'username': 'user-test',
@ -42,6 +51,10 @@ async def test_create_new_credential(provider):
}) })
assert credentials.is_new is True assert credentials.is_new is True
user = await manager.async_get_or_create_user(credentials)
assert user.name == 'Test Name'
assert user.is_active
async def test_match_existing_credentials(store, provider): async def test_match_existing_credentials(store, provider):
"""See if we match existing users.""" """See if we match existing users."""

View file

@ -30,12 +30,16 @@ def manager(hass, store, provider):
}) })
async def test_create_new_credential(provider): async def test_create_new_credential(manager, provider):
"""Test that we create a new credential.""" """Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({}) credentials = await provider.async_get_or_create_credentials({})
assert credentials.data["username"] is legacy_api_password.LEGACY_USER assert credentials.data["username"] is legacy_api_password.LEGACY_USER
assert credentials.is_new is True assert credentials.is_new is True
user = await manager.async_get_or_create_user(credentials)
assert user.name == legacy_api_password.LEGACY_USER
assert user.is_active
async def test_only_one_credentials(manager, provider): async def test_only_one_credentials(manager, provider):
"""Call create twice will return same credential.""" """Call create twice will return same credential."""

View file

@ -40,11 +40,31 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
'code': code 'code': code
}) })
# User is not active assert resp.status == 200
assert resp.status == 403 tokens = await resp.json()
data = await resp.json()
assert data['error'] == 'access_denied' assert hass.auth.async_get_access_token(tokens['access_token']) is not None
assert data['error_description'] == 'User is not active'
# Use refresh token to get more tokens.
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'refresh_token',
'refresh_token': tokens['refresh_token']
})
assert resp.status == 200
tokens = await resp.json()
assert 'refresh_token' not in tokens
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
# Test using access token to hit API.
resp = await client.get('/api/')
assert resp.status == 401
resp = await client.get('/api/', headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 200
def test_credential_store_expiration(): def test_credential_store_expiration():

View file

@ -47,7 +47,7 @@ async def test_add_user(hass, provider, capsys, hass_storage):
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'User created\n' assert captured.out == 'Auth created\n'
assert len(data.users) == 1 assert len(data.users) == 1
data.validate_login('paulus', 'test-pass') data.validate_login('paulus', 'test-pass')