Allow auth providers to influence is_active (#15557)
* Allow auth providers to influence is_active * Fix auth script test
This commit is contained in:
parent
a42288d056
commit
2fcacbff23
11 changed files with 82 additions and 23 deletions
|
@ -124,6 +124,7 @@ class AuthManager:
|
||||||
return await self._store.async_create_user(
|
return await self._store.async_create_user(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
name=info.get('name'),
|
name=info.get('name'),
|
||||||
|
is_active=info.get('is_active', False)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
async def async_link_user(self, user, credentials):
|
||||||
|
|
|
@ -135,5 +135,9 @@ class AuthProvider:
|
||||||
"""Return extra user metadata for credentials.
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
Will be used to populate info when creating a new user.
|
||||||
|
|
||||||
|
Values to populate:
|
||||||
|
- name: string
|
||||||
|
- is_active: boolean
|
||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -184,7 +184,8 @@ class HassAuthProvider(AuthProvider):
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
async def async_user_meta_for_credentials(self, credentials):
|
||||||
"""Get extra info for this credential."""
|
"""Get extra info for this credential."""
|
||||||
return {
|
return {
|
||||||
'name': credentials.data['username']
|
'name': credentials.data['username'],
|
||||||
|
'is_active': True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def async_will_remove_credentials(self, credentials):
|
async def async_will_remove_credentials(self, credentials):
|
||||||
|
|
|
@ -75,14 +75,16 @@ class ExampleAuthProvider(AuthProvider):
|
||||||
Will be used to populate info when creating a new user.
|
Will be used to populate info when creating a new user.
|
||||||
"""
|
"""
|
||||||
username = credentials.data['username']
|
username = credentials.data['username']
|
||||||
|
info = {
|
||||||
|
'is_active': True,
|
||||||
|
}
|
||||||
|
|
||||||
for user in self.config['users']:
|
for user in self.config['users']:
|
||||||
if user['username'] == username:
|
if user['username'] == username:
|
||||||
return {
|
info['name'] = user.get('name')
|
||||||
'name': user.get('name')
|
break
|
||||||
}
|
|
||||||
|
|
||||||
return {}
|
return info
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
|
@ -70,7 +70,10 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
Will be used to populate info when creating a new user.
|
||||||
"""
|
"""
|
||||||
return {'name': LEGACY_USER}
|
return {
|
||||||
|
'name': LEGACY_USER,
|
||||||
|
'is_active': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
|
@ -81,16 +81,9 @@ async def add_user(hass, provider, args):
|
||||||
print("Username already exists!")
|
print("Username already exists!")
|
||||||
return
|
return
|
||||||
|
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
|
||||||
'username': args.username
|
|
||||||
})
|
|
||||||
|
|
||||||
user = await hass.auth.async_create_user(args.username)
|
|
||||||
await hass.auth.async_link_user(user, credentials)
|
|
||||||
|
|
||||||
# Save username/password
|
# Save username/password
|
||||||
await provider.data.async_save()
|
await provider.data.async_save()
|
||||||
print("User created")
|
print("Auth created")
|
||||||
|
|
||||||
|
|
||||||
async def validate_login(hass, provider, args):
|
async def validate_login(hass, provider, args):
|
||||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import Mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.auth import auth_manager_from_config
|
||||||
from homeassistant.auth.providers import (
|
from homeassistant.auth.providers import (
|
||||||
auth_provider_from_config, homeassistant as hass_auth)
|
auth_provider_from_config, homeassistant as hass_auth)
|
||||||
|
|
||||||
|
@ -112,3 +113,20 @@ async def test_not_allow_set_id():
|
||||||
'id': 'invalid',
|
'id': 'invalid',
|
||||||
})
|
})
|
||||||
assert provider is None
|
assert provider is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_new_users_populate_values(hass, data):
|
||||||
|
"""Test that we populate data for new users."""
|
||||||
|
data.add_auth('hello', 'test-pass')
|
||||||
|
await data.async_save()
|
||||||
|
|
||||||
|
manager = await auth_manager_from_config(hass, [{
|
||||||
|
'type': 'homeassistant'
|
||||||
|
}])
|
||||||
|
provider = manager.auth_providers[0]
|
||||||
|
credentials = await provider.async_get_or_create_credentials({
|
||||||
|
'username': 'hello'
|
||||||
|
})
|
||||||
|
user = await manager.async_get_or_create_user(credentials)
|
||||||
|
assert user.name == 'hello'
|
||||||
|
assert user.is_active
|
||||||
|
|
|
@ -4,7 +4,7 @@ import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.auth import auth_store, models as auth_models
|
from homeassistant.auth import auth_store, models as auth_models, AuthManager
|
||||||
from homeassistant.auth.providers import insecure_example
|
from homeassistant.auth.providers import insecure_example
|
||||||
|
|
||||||
from tests.common import mock_coro
|
from tests.common import mock_coro
|
||||||
|
@ -23,6 +23,7 @@ def provider(hass, store):
|
||||||
'type': 'insecure_example',
|
'type': 'insecure_example',
|
||||||
'users': [
|
'users': [
|
||||||
{
|
{
|
||||||
|
'name': 'Test Name',
|
||||||
'username': 'user-test',
|
'username': 'user-test',
|
||||||
'password': 'password-test',
|
'password': 'password-test',
|
||||||
},
|
},
|
||||||
|
@ -34,7 +35,15 @@ def provider(hass, store):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
async def test_create_new_credential(provider):
|
@pytest.fixture
|
||||||
|
def manager(hass, store, provider):
|
||||||
|
"""Mock manager."""
|
||||||
|
return AuthManager(hass, store, {
|
||||||
|
(provider.type, provider.id): provider
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_new_credential(manager, provider):
|
||||||
"""Test that we create a new credential."""
|
"""Test that we create a new credential."""
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
credentials = await provider.async_get_or_create_credentials({
|
||||||
'username': 'user-test',
|
'username': 'user-test',
|
||||||
|
@ -42,6 +51,10 @@ async def test_create_new_credential(provider):
|
||||||
})
|
})
|
||||||
assert credentials.is_new is True
|
assert credentials.is_new is True
|
||||||
|
|
||||||
|
user = await manager.async_get_or_create_user(credentials)
|
||||||
|
assert user.name == 'Test Name'
|
||||||
|
assert user.is_active
|
||||||
|
|
||||||
|
|
||||||
async def test_match_existing_credentials(store, provider):
|
async def test_match_existing_credentials(store, provider):
|
||||||
"""See if we match existing users."""
|
"""See if we match existing users."""
|
||||||
|
|
|
@ -30,12 +30,16 @@ def manager(hass, store, provider):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
async def test_create_new_credential(provider):
|
async def test_create_new_credential(manager, provider):
|
||||||
"""Test that we create a new credential."""
|
"""Test that we create a new credential."""
|
||||||
credentials = await provider.async_get_or_create_credentials({})
|
credentials = await provider.async_get_or_create_credentials({})
|
||||||
assert credentials.data["username"] is legacy_api_password.LEGACY_USER
|
assert credentials.data["username"] is legacy_api_password.LEGACY_USER
|
||||||
assert credentials.is_new is True
|
assert credentials.is_new is True
|
||||||
|
|
||||||
|
user = await manager.async_get_or_create_user(credentials)
|
||||||
|
assert user.name == legacy_api_password.LEGACY_USER
|
||||||
|
assert user.is_active
|
||||||
|
|
||||||
|
|
||||||
async def test_only_one_credentials(manager, provider):
|
async def test_only_one_credentials(manager, provider):
|
||||||
"""Call create twice will return same credential."""
|
"""Call create twice will return same credential."""
|
||||||
|
|
|
@ -40,11 +40,31 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||||
'code': code
|
'code': code
|
||||||
})
|
})
|
||||||
|
|
||||||
# User is not active
|
assert resp.status == 200
|
||||||
assert resp.status == 403
|
tokens = await resp.json()
|
||||||
data = await resp.json()
|
|
||||||
assert data['error'] == 'access_denied'
|
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||||
assert data['error_description'] == 'User is not active'
|
|
||||||
|
# Use refresh token to get more tokens.
|
||||||
|
resp = await client.post('/auth/token', data={
|
||||||
|
'client_id': CLIENT_ID,
|
||||||
|
'grant_type': 'refresh_token',
|
||||||
|
'refresh_token': tokens['refresh_token']
|
||||||
|
})
|
||||||
|
|
||||||
|
assert resp.status == 200
|
||||||
|
tokens = await resp.json()
|
||||||
|
assert 'refresh_token' not in tokens
|
||||||
|
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||||
|
|
||||||
|
# Test using access token to hit API.
|
||||||
|
resp = await client.get('/api/')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
resp = await client.get('/api/', headers={
|
||||||
|
'authorization': 'Bearer {}'.format(tokens['access_token'])
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
def test_credential_store_expiration():
|
def test_credential_store_expiration():
|
||||||
|
|
|
@ -47,7 +47,7 @@ async def test_add_user(hass, provider, capsys, hass_storage):
|
||||||
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||||
|
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'User created\n'
|
assert captured.out == 'Auth created\n'
|
||||||
|
|
||||||
assert len(data.users) == 1
|
assert len(data.users) == 1
|
||||||
data.validate_login('paulus', 'test-pass')
|
data.validate_login('paulus', 'test-pass')
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue