Only create front-end client_id once (#15214)
* Only create frontend client_id once * Check user and client_id before create refresh token * Lint * Follow code review comment * Minor clenaup * Update doc string
This commit is contained in:
parent
dffe36761d
commit
a64a66dd62
5 changed files with 121 additions and 51 deletions
|
@ -1,23 +1,22 @@
|
|||
"""Provide an authentication layer for Home Assistant."""
|
||||
import asyncio
|
||||
import binascii
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -349,6 +348,16 @@ class AuthManager:
|
|||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_or_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
"""Find a client, if not exists, create a new one."""
|
||||
for client in await self._store.async_get_clients():
|
||||
if client.name == name:
|
||||
return client
|
||||
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
return await self._store.async_get_client(client_id)
|
||||
|
@ -392,29 +401,36 @@ class AuthStore:
|
|||
def __init__(self, hass):
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self.users = None
|
||||
self.clients = None
|
||||
self._users = None
|
||||
self._clients = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def credentials_for_provider(self, provider_type, provider_id):
|
||||
"""Return credentials for specific auth provider type and id."""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return [
|
||||
credentials
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == provider_type and
|
||||
credentials.auth_provider_id == provider_id)
|
||||
]
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
if self.users is None:
|
||||
async def async_get_users(self):
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self.users.get(user_id)
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_get_or_create_user(self, credentials, auth_provider):
|
||||
"""Get or create a new user for given credentials.
|
||||
|
@ -422,7 +438,7 @@ class AuthStore:
|
|||
If link_user is passed in, the credentials will be linked to the passed
|
||||
in user if the credentials are new.
|
||||
"""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
# New credentials, store in user
|
||||
|
@ -430,7 +446,7 @@ class AuthStore:
|
|||
info = await auth_provider.async_user_meta_for_credentials(
|
||||
credentials)
|
||||
# Make owner and activate user if it's the first user.
|
||||
if self.users:
|
||||
if self._users:
|
||||
is_owner = False
|
||||
is_active = False
|
||||
else:
|
||||
|
@ -442,11 +458,11 @@ class AuthStore:
|
|||
is_active=is_active,
|
||||
name=info.get('name'),
|
||||
)
|
||||
self.users[new_user.id] = new_user
|
||||
self._users[new_user.id] = new_user
|
||||
await self.async_link_user(new_user, credentials)
|
||||
return new_user
|
||||
|
||||
for user in self.users.values():
|
||||
for user in self._users.values():
|
||||
for creds in user.credentials:
|
||||
if (creds.auth_provider_type == credentials.auth_provider_type
|
||||
and creds.auth_provider_id ==
|
||||
|
@ -463,11 +479,19 @@ class AuthStore:
|
|||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
self.users.pop(user.id)
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id):
|
||||
"""Create a new token for a user."""
|
||||
local_user = await self.async_get_user(user.id)
|
||||
if local_user is None:
|
||||
raise ValueError('Invalid user')
|
||||
|
||||
local_client = await self.async_get_client(client_id)
|
||||
if local_client is None:
|
||||
raise ValueError('Invalid client_id')
|
||||
|
||||
refresh_token = RefreshToken(user, client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
await self.async_save()
|
||||
|
@ -475,10 +499,10 @@ class AuthStore:
|
|||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
for user in self.users.values():
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
|
@ -487,7 +511,7 @@ class AuthStore:
|
|||
|
||||
async def async_create_client(self, name, redirect_uris, no_secret):
|
||||
"""Create a new client."""
|
||||
if self.clients is None:
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
kwargs = {
|
||||
|
@ -499,16 +523,23 @@ class AuthStore:
|
|||
kwargs['secret'] = None
|
||||
|
||||
client = Client(**kwargs)
|
||||
self.clients[client.id] = client
|
||||
self._clients[client.id] = client
|
||||
await self.async_save()
|
||||
return client
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
if self.clients is None:
|
||||
async def async_get_clients(self):
|
||||
"""Return all clients."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return self.clients.get(client_id)
|
||||
return list(self._clients.values())
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._clients.get(client_id)
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
|
@ -516,12 +547,12 @@ class AuthStore:
|
|||
|
||||
# Make sure that we're not overriding data if 2 loads happened at the
|
||||
# same time
|
||||
if self.users is not None:
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
self.users = {}
|
||||
self.clients = {}
|
||||
self._users = {}
|
||||
self._clients = {}
|
||||
return
|
||||
|
||||
users = {
|
||||
|
@ -565,8 +596,8 @@ class AuthStore:
|
|||
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
|
||||
}
|
||||
|
||||
self.users = users
|
||||
self.clients = clients
|
||||
self._users = users
|
||||
self._clients = clients
|
||||
|
||||
async def async_save(self):
|
||||
"""Save users."""
|
||||
|
@ -577,7 +608,7 @@ class AuthStore:
|
|||
'is_active': user.is_active,
|
||||
'name': user.name,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
]
|
||||
|
||||
credentials = [
|
||||
|
@ -588,7 +619,7 @@ class AuthStore:
|
|||
'auth_provider_id': credential.auth_provider_id,
|
||||
'data': credential.data,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for credential in user.credentials
|
||||
]
|
||||
|
||||
|
@ -602,7 +633,7 @@ class AuthStore:
|
|||
refresh_token.access_token_expiration.total_seconds(),
|
||||
'token': refresh_token.token,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
|
@ -613,7 +644,7 @@ class AuthStore:
|
|||
'created_at': access_token.created_at.isoformat(),
|
||||
'token': access_token.token,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
|
@ -625,7 +656,7 @@ class AuthStore:
|
|||
'secret': client.secret,
|
||||
'redirect_uris': client.redirect_uris,
|
||||
}
|
||||
for client in self.clients.values()
|
||||
for client in self._clients.values()
|
||||
]
|
||||
|
||||
data = {
|
||||
|
|
|
@ -201,7 +201,7 @@ def add_manifest_json_key(key, val):
|
|||
async def async_setup(hass, config):
|
||||
"""Set up the serving of the frontend."""
|
||||
if hass.auth.active:
|
||||
client = await hass.auth.async_create_client(
|
||||
client = await hass.auth.async_get_or_create_client(
|
||||
'Home Assistant Frontend',
|
||||
redirect_uris=['/'],
|
||||
no_secret=True,
|
||||
|
|
|
@ -321,7 +321,7 @@ class MockUser(auth.User):
|
|||
def add_to_auth_manager(self, auth_mgr):
|
||||
"""Test helper to add entry to hass."""
|
||||
ensure_auth_manager_loaded(auth_mgr)
|
||||
auth_mgr._store.users[self.id] = self
|
||||
auth_mgr._store._users[self.id] = self
|
||||
return self
|
||||
|
||||
|
||||
|
@ -329,10 +329,10 @@ class MockUser(auth.User):
|
|||
def ensure_auth_manager_loaded(auth_mgr):
|
||||
"""Ensure an auth manager is considered loaded."""
|
||||
store = auth_mgr._store
|
||||
if store.clients is None:
|
||||
store.clients = {}
|
||||
if store.users is None:
|
||||
store.users = {}
|
||||
if store._clients is None:
|
||||
store._clients = {}
|
||||
if store._users is None:
|
||||
store._users = {}
|
||||
|
||||
|
||||
class MockModule(object):
|
||||
|
|
|
@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
|||
})
|
||||
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
|
||||
redirect_uris=[CLIENT_REDIRECT_URI])
|
||||
hass.auth._store.clients[client.id] = client
|
||||
hass.auth._store._clients[client.id] = client
|
||||
if setup_api:
|
||||
await async_setup_component(hass, 'api', {})
|
||||
return await aiohttp_client(hass.http.app)
|
||||
|
|
|
@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage):
|
|||
await flush_store(manager._store._store)
|
||||
|
||||
store2 = auth.AuthStore(hass)
|
||||
await store2.async_load()
|
||||
assert len(store2.users) == 1
|
||||
assert store2.users[user.id] == user
|
||||
users = await store2.async_get_users()
|
||||
assert len(users) == 1
|
||||
assert users[0] == user
|
||||
|
||||
assert len(store2.clients) == 1
|
||||
assert store2.clients[client.id] == client
|
||||
clients = await store2.async_get_clients()
|
||||
assert len(clients) == 1
|
||||
assert clients[0] == client
|
||||
|
||||
|
||||
def test_access_token_expired():
|
||||
|
@ -224,15 +225,18 @@ def test_access_token_expired():
|
|||
async def test_cannot_retrieve_expired_access_token(hass):
|
||||
"""Test that we cannot retrieve expired access tokens."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
client = await manager.async_create_client('test')
|
||||
user = MockUser(
|
||||
id='mock-user',
|
||||
is_owner=False,
|
||||
is_active=False,
|
||||
name='Paulus',
|
||||
).add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(user, 'bla')
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
refresh_token = await manager.async_create_refresh_token(user, client.id)
|
||||
assert refresh_token.user.id is user.id
|
||||
assert refresh_token.client_id is client.id
|
||||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
assert manager.async_get_access_token(access_token.token) is access_token
|
||||
|
||||
with patch('homeassistant.auth.dt_util.utcnow',
|
||||
|
@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
|||
|
||||
# Even with unpatched time, it should have been removed from manager
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
|
||||
|
||||
async def test_get_or_create_client(hass):
|
||||
"""Test that get_or_create_client works."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
|
||||
client1 = await manager.async_get_or_create_client(
|
||||
'Test Client', redirect_uris=['https://test.com/1'])
|
||||
assert client1.name is 'Test Client'
|
||||
|
||||
client2 = await manager.async_get_or_create_client(
|
||||
'Test Client', redirect_uris=['https://test.com/1'])
|
||||
assert client2.id is client1.id
|
||||
|
||||
|
||||
async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
|
||||
"""Test that we cannot create refresh token with invalid client id."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
user = MockUser(
|
||||
id='mock-user',
|
||||
is_owner=False,
|
||||
is_active=False,
|
||||
name='Paulus',
|
||||
).add_to_auth_manager(manager)
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_refresh_token(user, 'bla')
|
||||
|
||||
|
||||
async def test_cannot_create_refresh_token_with_invalide_user(hass):
|
||||
"""Test that we cannot create refresh token with invalid client id."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
client = await manager.async_create_client('test')
|
||||
user = MockUser(id='invalid-user')
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_refresh_token(user, client.id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue