Only create front-end client_id once (#15214)
* Only create frontend client_id once * Check user and client_id before create refresh token * Lint * Follow code review comment * Minor clenaup * Update doc string
This commit is contained in:
parent
dffe36761d
commit
a64a66dd62
5 changed files with 121 additions and 51 deletions
|
@ -1,23 +1,22 @@
|
||||||
"""Provide an authentication layer for Home Assistant."""
|
"""Provide an authentication layer for Home Assistant."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import binascii
|
import binascii
|
||||||
from collections import OrderedDict
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import os
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import OrderedDict
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant import data_entry_flow, requirements
|
from homeassistant import data_entry_flow, requirements
|
||||||
from homeassistant.core import callback
|
|
||||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.core import callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -349,6 +348,16 @@ class AuthManager:
|
||||||
return await self._store.async_create_client(
|
return await self._store.async_create_client(
|
||||||
name, redirect_uris, no_secret)
|
name, redirect_uris, no_secret)
|
||||||
|
|
||||||
|
async def async_get_or_create_client(self, name, *, redirect_uris=None,
|
||||||
|
no_secret=False):
|
||||||
|
"""Find a client, if not exists, create a new one."""
|
||||||
|
for client in await self._store.async_get_clients():
|
||||||
|
if client.name == name:
|
||||||
|
return client
|
||||||
|
|
||||||
|
return await self._store.async_create_client(
|
||||||
|
name, redirect_uris, no_secret)
|
||||||
|
|
||||||
async def async_get_client(self, client_id):
|
async def async_get_client(self, client_id):
|
||||||
"""Get a client."""
|
"""Get a client."""
|
||||||
return await self._store.async_get_client(client_id)
|
return await self._store.async_get_client(client_id)
|
||||||
|
@ -392,29 +401,36 @@ class AuthStore:
|
||||||
def __init__(self, hass):
|
def __init__(self, hass):
|
||||||
"""Initialize the auth store."""
|
"""Initialize the auth store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.users = None
|
self._users = None
|
||||||
self.clients = None
|
self._clients = None
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
|
||||||
async def credentials_for_provider(self, provider_type, provider_id):
|
async def credentials_for_provider(self, provider_type, provider_id):
|
||||||
"""Return credentials for specific auth provider type and id."""
|
"""Return credentials for specific auth provider type and id."""
|
||||||
if self.users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
credentials
|
credentials
|
||||||
for user in self.users.values()
|
for user in self._users.values()
|
||||||
for credentials in user.credentials
|
for credentials in user.credentials
|
||||||
if (credentials.auth_provider_type == provider_type and
|
if (credentials.auth_provider_type == provider_type and
|
||||||
credentials.auth_provider_id == provider_id)
|
credentials.auth_provider_id == provider_id)
|
||||||
]
|
]
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
async def async_get_users(self):
|
||||||
"""Retrieve a user."""
|
"""Retrieve all users."""
|
||||||
if self.users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
return self.users.get(user_id)
|
return list(self._users.values())
|
||||||
|
|
||||||
|
async def async_get_user(self, user_id):
|
||||||
|
"""Retrieve a user."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
return self._users.get(user_id)
|
||||||
|
|
||||||
async def async_get_or_create_user(self, credentials, auth_provider):
|
async def async_get_or_create_user(self, credentials, auth_provider):
|
||||||
"""Get or create a new user for given credentials.
|
"""Get or create a new user for given credentials.
|
||||||
|
@ -422,7 +438,7 @@ class AuthStore:
|
||||||
If link_user is passed in, the credentials will be linked to the passed
|
If link_user is passed in, the credentials will be linked to the passed
|
||||||
in user if the credentials are new.
|
in user if the credentials are new.
|
||||||
"""
|
"""
|
||||||
if self.users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
# New credentials, store in user
|
# New credentials, store in user
|
||||||
|
@ -430,7 +446,7 @@ class AuthStore:
|
||||||
info = await auth_provider.async_user_meta_for_credentials(
|
info = await auth_provider.async_user_meta_for_credentials(
|
||||||
credentials)
|
credentials)
|
||||||
# Make owner and activate user if it's the first user.
|
# Make owner and activate user if it's the first user.
|
||||||
if self.users:
|
if self._users:
|
||||||
is_owner = False
|
is_owner = False
|
||||||
is_active = False
|
is_active = False
|
||||||
else:
|
else:
|
||||||
|
@ -442,11 +458,11 @@ class AuthStore:
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
name=info.get('name'),
|
name=info.get('name'),
|
||||||
)
|
)
|
||||||
self.users[new_user.id] = new_user
|
self._users[new_user.id] = new_user
|
||||||
await self.async_link_user(new_user, credentials)
|
await self.async_link_user(new_user, credentials)
|
||||||
return new_user
|
return new_user
|
||||||
|
|
||||||
for user in self.users.values():
|
for user in self._users.values():
|
||||||
for creds in user.credentials:
|
for creds in user.credentials:
|
||||||
if (creds.auth_provider_type == credentials.auth_provider_type
|
if (creds.auth_provider_type == credentials.auth_provider_type
|
||||||
and creds.auth_provider_id ==
|
and creds.auth_provider_id ==
|
||||||
|
@ -463,11 +479,19 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_remove_user(self, user):
|
async def async_remove_user(self, user):
|
||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
self.users.pop(user.id)
|
self._users.pop(user.id)
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id):
|
async def async_create_refresh_token(self, user, client_id):
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
|
local_user = await self.async_get_user(user.id)
|
||||||
|
if local_user is None:
|
||||||
|
raise ValueError('Invalid user')
|
||||||
|
|
||||||
|
local_client = await self.async_get_client(client_id)
|
||||||
|
if local_client is None:
|
||||||
|
raise ValueError('Invalid client_id')
|
||||||
|
|
||||||
refresh_token = RefreshToken(user, client_id)
|
refresh_token = RefreshToken(user, client_id)
|
||||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
@ -475,10 +499,10 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
async def async_get_refresh_token(self, token):
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
if self.users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
for user in self.users.values():
|
for user in self._users.values():
|
||||||
refresh_token = user.refresh_tokens.get(token)
|
refresh_token = user.refresh_tokens.get(token)
|
||||||
if refresh_token is not None:
|
if refresh_token is not None:
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
@ -487,7 +511,7 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_create_client(self, name, redirect_uris, no_secret):
|
async def async_create_client(self, name, redirect_uris, no_secret):
|
||||||
"""Create a new client."""
|
"""Create a new client."""
|
||||||
if self.clients is None:
|
if self._clients is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -499,16 +523,23 @@ class AuthStore:
|
||||||
kwargs['secret'] = None
|
kwargs['secret'] = None
|
||||||
|
|
||||||
client = Client(**kwargs)
|
client = Client(**kwargs)
|
||||||
self.clients[client.id] = client
|
self._clients[client.id] = client
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
return client
|
return client
|
||||||
|
|
||||||
async def async_get_client(self, client_id):
|
async def async_get_clients(self):
|
||||||
"""Get a client."""
|
"""Return all clients."""
|
||||||
if self.clients is None:
|
if self._clients is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
return self.clients.get(client_id)
|
return list(self._clients.values())
|
||||||
|
|
||||||
|
async def async_get_client(self, client_id):
|
||||||
|
"""Get a client."""
|
||||||
|
if self._clients is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
return self._clients.get(client_id)
|
||||||
|
|
||||||
async def async_load(self):
|
async def async_load(self):
|
||||||
"""Load the users."""
|
"""Load the users."""
|
||||||
|
@ -516,12 +547,12 @@ class AuthStore:
|
||||||
|
|
||||||
# Make sure that we're not overriding data if 2 loads happened at the
|
# Make sure that we're not overriding data if 2 loads happened at the
|
||||||
# same time
|
# same time
|
||||||
if self.users is not None:
|
if self._users is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
self.users = {}
|
self._users = {}
|
||||||
self.clients = {}
|
self._clients = {}
|
||||||
return
|
return
|
||||||
|
|
||||||
users = {
|
users = {
|
||||||
|
@ -565,8 +596,8 @@ class AuthStore:
|
||||||
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
|
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
|
||||||
}
|
}
|
||||||
|
|
||||||
self.users = users
|
self._users = users
|
||||||
self.clients = clients
|
self._clients = clients
|
||||||
|
|
||||||
async def async_save(self):
|
async def async_save(self):
|
||||||
"""Save users."""
|
"""Save users."""
|
||||||
|
@ -577,7 +608,7 @@ class AuthStore:
|
||||||
'is_active': user.is_active,
|
'is_active': user.is_active,
|
||||||
'name': user.name,
|
'name': user.name,
|
||||||
}
|
}
|
||||||
for user in self.users.values()
|
for user in self._users.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
credentials = [
|
credentials = [
|
||||||
|
@ -588,7 +619,7 @@ class AuthStore:
|
||||||
'auth_provider_id': credential.auth_provider_id,
|
'auth_provider_id': credential.auth_provider_id,
|
||||||
'data': credential.data,
|
'data': credential.data,
|
||||||
}
|
}
|
||||||
for user in self.users.values()
|
for user in self._users.values()
|
||||||
for credential in user.credentials
|
for credential in user.credentials
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -602,7 +633,7 @@ class AuthStore:
|
||||||
refresh_token.access_token_expiration.total_seconds(),
|
refresh_token.access_token_expiration.total_seconds(),
|
||||||
'token': refresh_token.token,
|
'token': refresh_token.token,
|
||||||
}
|
}
|
||||||
for user in self.users.values()
|
for user in self._users.values()
|
||||||
for refresh_token in user.refresh_tokens.values()
|
for refresh_token in user.refresh_tokens.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -613,7 +644,7 @@ class AuthStore:
|
||||||
'created_at': access_token.created_at.isoformat(),
|
'created_at': access_token.created_at.isoformat(),
|
||||||
'token': access_token.token,
|
'token': access_token.token,
|
||||||
}
|
}
|
||||||
for user in self.users.values()
|
for user in self._users.values()
|
||||||
for refresh_token in user.refresh_tokens.values()
|
for refresh_token in user.refresh_tokens.values()
|
||||||
for access_token in refresh_token.access_tokens
|
for access_token in refresh_token.access_tokens
|
||||||
]
|
]
|
||||||
|
@ -625,7 +656,7 @@ class AuthStore:
|
||||||
'secret': client.secret,
|
'secret': client.secret,
|
||||||
'redirect_uris': client.redirect_uris,
|
'redirect_uris': client.redirect_uris,
|
||||||
}
|
}
|
||||||
for client in self.clients.values()
|
for client in self._clients.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
|
|
@ -201,7 +201,7 @@ def add_manifest_json_key(key, val):
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Set up the serving of the frontend."""
|
"""Set up the serving of the frontend."""
|
||||||
if hass.auth.active:
|
if hass.auth.active:
|
||||||
client = await hass.auth.async_create_client(
|
client = await hass.auth.async_get_or_create_client(
|
||||||
'Home Assistant Frontend',
|
'Home Assistant Frontend',
|
||||||
redirect_uris=['/'],
|
redirect_uris=['/'],
|
||||||
no_secret=True,
|
no_secret=True,
|
||||||
|
|
|
@ -321,7 +321,7 @@ class MockUser(auth.User):
|
||||||
def add_to_auth_manager(self, auth_mgr):
|
def add_to_auth_manager(self, auth_mgr):
|
||||||
"""Test helper to add entry to hass."""
|
"""Test helper to add entry to hass."""
|
||||||
ensure_auth_manager_loaded(auth_mgr)
|
ensure_auth_manager_loaded(auth_mgr)
|
||||||
auth_mgr._store.users[self.id] = self
|
auth_mgr._store._users[self.id] = self
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@ -329,10 +329,10 @@ class MockUser(auth.User):
|
||||||
def ensure_auth_manager_loaded(auth_mgr):
|
def ensure_auth_manager_loaded(auth_mgr):
|
||||||
"""Ensure an auth manager is considered loaded."""
|
"""Ensure an auth manager is considered loaded."""
|
||||||
store = auth_mgr._store
|
store = auth_mgr._store
|
||||||
if store.clients is None:
|
if store._clients is None:
|
||||||
store.clients = {}
|
store._clients = {}
|
||||||
if store.users is None:
|
if store._users is None:
|
||||||
store.users = {}
|
store._users = {}
|
||||||
|
|
||||||
|
|
||||||
class MockModule(object):
|
class MockModule(object):
|
||||||
|
|
|
@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
||||||
})
|
})
|
||||||
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
|
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
|
||||||
redirect_uris=[CLIENT_REDIRECT_URI])
|
redirect_uris=[CLIENT_REDIRECT_URI])
|
||||||
hass.auth._store.clients[client.id] = client
|
hass.auth._store._clients[client.id] = client
|
||||||
if setup_api:
|
if setup_api:
|
||||||
await async_setup_component(hass, 'api', {})
|
await async_setup_component(hass, 'api', {})
|
||||||
return await aiohttp_client(hass.http.app)
|
return await aiohttp_client(hass.http.app)
|
||||||
|
|
|
@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage):
|
||||||
await flush_store(manager._store._store)
|
await flush_store(manager._store._store)
|
||||||
|
|
||||||
store2 = auth.AuthStore(hass)
|
store2 = auth.AuthStore(hass)
|
||||||
await store2.async_load()
|
users = await store2.async_get_users()
|
||||||
assert len(store2.users) == 1
|
assert len(users) == 1
|
||||||
assert store2.users[user.id] == user
|
assert users[0] == user
|
||||||
|
|
||||||
assert len(store2.clients) == 1
|
clients = await store2.async_get_clients()
|
||||||
assert store2.clients[client.id] == client
|
assert len(clients) == 1
|
||||||
|
assert clients[0] == client
|
||||||
|
|
||||||
|
|
||||||
def test_access_token_expired():
|
def test_access_token_expired():
|
||||||
|
@ -224,15 +225,18 @@ def test_access_token_expired():
|
||||||
async def test_cannot_retrieve_expired_access_token(hass):
|
async def test_cannot_retrieve_expired_access_token(hass):
|
||||||
"""Test that we cannot retrieve expired access tokens."""
|
"""Test that we cannot retrieve expired access tokens."""
|
||||||
manager = await auth.auth_manager_from_config(hass, [])
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
client = await manager.async_create_client('test')
|
||||||
user = MockUser(
|
user = MockUser(
|
||||||
id='mock-user',
|
id='mock-user',
|
||||||
is_owner=False,
|
is_owner=False,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
name='Paulus',
|
name='Paulus',
|
||||||
).add_to_auth_manager(manager)
|
).add_to_auth_manager(manager)
|
||||||
refresh_token = await manager.async_create_refresh_token(user, 'bla')
|
refresh_token = await manager.async_create_refresh_token(user, client.id)
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
assert refresh_token.user.id is user.id
|
||||||
|
assert refresh_token.client_id is client.id
|
||||||
|
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
assert manager.async_get_access_token(access_token.token) is access_token
|
assert manager.async_get_access_token(access_token.token) is access_token
|
||||||
|
|
||||||
with patch('homeassistant.auth.dt_util.utcnow',
|
with patch('homeassistant.auth.dt_util.utcnow',
|
||||||
|
@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
||||||
|
|
||||||
# Even with unpatched time, it should have been removed from manager
|
# Even with unpatched time, it should have been removed from manager
|
||||||
assert manager.async_get_access_token(access_token.token) is None
|
assert manager.async_get_access_token(access_token.token) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_or_create_client(hass):
|
||||||
|
"""Test that get_or_create_client works."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
|
||||||
|
client1 = await manager.async_get_or_create_client(
|
||||||
|
'Test Client', redirect_uris=['https://test.com/1'])
|
||||||
|
assert client1.name is 'Test Client'
|
||||||
|
|
||||||
|
client2 = await manager.async_get_or_create_client(
|
||||||
|
'Test Client', redirect_uris=['https://test.com/1'])
|
||||||
|
assert client2.id is client1.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
|
||||||
|
"""Test that we cannot create refresh token with invalid client id."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
user = MockUser(
|
||||||
|
id='mock-user',
|
||||||
|
is_owner=False,
|
||||||
|
is_active=False,
|
||||||
|
name='Paulus',
|
||||||
|
).add_to_auth_manager(manager)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.async_create_refresh_token(user, 'bla')
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cannot_create_refresh_token_with_invalide_user(hass):
|
||||||
|
"""Test that we cannot create refresh token with invalid client id."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
client = await manager.async_create_client('test')
|
||||||
|
user = MockUser(id='invalid-user')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.async_create_refresh_token(user, client.id)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue