diff --git a/homeassistant/auth.py b/homeassistant/auth.py index a4e8ee05943..e6760cd9096 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -79,7 +79,14 @@ class AuthProvider: async def async_credentials(self): """Return all credentials of this provider.""" - return await self.store.credentials_for_provider(self.type, self.id) + users = await self.store.async_get_users() + return [ + credentials + for user in users + for credentials in user.credentials + if (credentials.auth_provider_type == self.type and + credentials.auth_provider_id == self.id) + ] @callback def async_create_credentials(self, data): @@ -118,10 +125,11 @@ class AuthProvider: class User: """A user.""" + name = attr.ib(type=str) id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) is_owner = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False) - name = attr.ib(type=str, default=None) + system_generated = attr.ib(type=bool, default=False) # List of credentials of a user. credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False) @@ -300,10 +308,45 @@ class AuthManager: """Retrieve a user.""" return await self._store.async_get_user(user_id) + async def async_create_system_user(self, name): + """Create a system user.""" + return await self._store.async_create_user( + name=name, + system_generated=True, + is_active=True, + ) + async def async_get_or_create_user(self, credentials): """Get or create a user.""" - return await self._store.async_get_or_create_user( - credentials, self._async_get_auth_provider(credentials)) + if not credentials.is_new: + for user in await self._store.async_get_users(): + for creds in user.credentials: + if (creds.auth_provider_type == + credentials.auth_provider_type + and creds.auth_provider_id == + credentials.auth_provider_id): + return user + + raise ValueError('Unable to find the user.') + + auth_provider = self._async_get_auth_provider(credentials) + info = await auth_provider.async_user_meta_for_credentials( + credentials) + + kwargs = { + 'credentials': credentials, + 'name': info.get('name') + } + + # Make owner and activate user if it's the first user. + if await self._store.async_get_users(): + kwargs['is_owner'] = False + kwargs['is_active'] = False + else: + kwargs['is_owner'] = True + kwargs['is_active'] = True + + return await self._store.async_create_user(**kwargs) async def async_link_user(self, user, credentials): """Link credentials to an existing user.""" @@ -313,9 +356,20 @@ class AuthManager: """Remove a user.""" await self._store.async_remove_user(user) - async def async_create_refresh_token(self, user, client_id): + async def async_create_refresh_token(self, user, client=None): """Create a new refresh token for a user.""" - return await self._store.async_create_refresh_token(user, client_id) + if not user.is_active: + raise ValueError('User is not active') + + if user.system_generated and client is not None: + raise ValueError( + 'System generated users cannot have refresh tokens connected ' + 'to a client.') + + if not user.system_generated and client is None: + raise ValueError('Client is required to generate a refresh token.') + + return await self._store.async_create_refresh_token(user, client) async def async_get_refresh_token(self, token): """Get refresh token by token.""" @@ -324,7 +378,7 @@ class AuthManager: @callback def async_create_access_token(self, refresh_token): """Create a new access token.""" - access_token = AccessToken(refresh_token) + access_token = AccessToken(refresh_token=refresh_token) self._access_tokens[access_token.token] = access_token return access_token @@ -405,19 +459,6 @@ class AuthStore: self._clients = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) - async def credentials_for_provider(self, provider_type, provider_id): - """Return credentials for specific auth provider type and id.""" - if self._users is None: - await self.async_load() - - return [ - credentials - for user in self._users.values() - for credentials in user.credentials - if (credentials.auth_provider_type == provider_type and - credentials.auth_provider_id == provider_id) - ] - async def async_get_users(self): """Retrieve all users.""" if self._users is None: @@ -426,50 +467,42 @@ class AuthStore: return list(self._users.values()) async def async_get_user(self, user_id): - """Retrieve a user.""" + """Retrieve a user by id.""" if self._users is None: await self.async_load() return self._users.get(user_id) - async def async_get_or_create_user(self, credentials, auth_provider): - """Get or create a new user for given credentials. - - If link_user is passed in, the credentials will be linked to the passed - in user if the credentials are new. - """ + async def async_create_user(self, name, is_owner=None, is_active=None, + system_generated=None, credentials=None): + """Create a new user.""" if self._users is None: await self.async_load() - # New credentials, store in user - if credentials.is_new: - info = await auth_provider.async_user_meta_for_credentials( - credentials) - # Make owner and activate user if it's the first user. - if self._users: - is_owner = False - is_active = False - else: - is_owner = True - is_active = True + kwargs = { + 'name': name + } - new_user = User( - is_owner=is_owner, - is_active=is_active, - name=info.get('name'), - ) - self._users[new_user.id] = new_user - await self.async_link_user(new_user, credentials) + if is_owner is not None: + kwargs['is_owner'] = is_owner + + if is_active is not None: + kwargs['is_active'] = is_active + + if system_generated is not None: + kwargs['system_generated'] = system_generated + + new_user = User(**kwargs) + + self._users[new_user.id] = new_user + + if credentials is None: + await self.async_save() return new_user - for user in self._users.values(): - for creds in user.credentials: - if (creds.auth_provider_type == credentials.auth_provider_type - and creds.auth_provider_id == - credentials.auth_provider_id): - return user - - raise ValueError('We got credentials with ID but found no user') + # Saving is done inside the link. + await self.async_link_user(new_user, credentials) + return new_user async def async_link_user(self, user, credentials): """Add credentials to an existing user.""" @@ -482,17 +515,10 @@ class AuthStore: self._users.pop(user.id) await self.async_save() - async def async_create_refresh_token(self, user, client_id): + async def async_create_refresh_token(self, user, client=None): """Create a new token for a user.""" - local_user = await self.async_get_user(user.id) - if local_user is None: - raise ValueError('Invalid user') - - local_client = await self.async_get_client(client_id) - if local_client is None: - raise ValueError('Invalid client_id') - - refresh_token = RefreshToken(user, client_id) + client_id = client.id if client is not None else None + refresh_token = RefreshToken(user=user, client_id=client_id) user.refresh_tokens[refresh_token.token] = refresh_token await self.async_save() return refresh_token @@ -607,6 +633,7 @@ class AuthStore: 'is_owner': user.is_owner, 'is_active': user.is_active, 'name': user.name, + 'system_generated': user.system_generated, } for user in self._users.values() ] diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 0f7295a41e0..511999c52ab 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -236,18 +236,16 @@ class GrantTokenView(HomeAssistantView): grant_type = data.get('grant_type') if grant_type == 'authorization_code': - return await self._async_handle_auth_code( - hass, client.id, data) + return await self._async_handle_auth_code(hass, client, data) elif grant_type == 'refresh_token': - return await self._async_handle_refresh_token( - hass, client.id, data) + return await self._async_handle_refresh_token(hass, client, data) return self.json({ 'error': 'unsupported_grant_type', }, status_code=400) - async def _async_handle_auth_code(self, hass, client_id, data): + async def _async_handle_auth_code(self, hass, client, data): """Handle authorization code request.""" code = data.get('code') @@ -256,7 +254,7 @@ class GrantTokenView(HomeAssistantView): 'error': 'invalid_request', }, status_code=400) - credentials = self._retrieve_credentials(client_id, code) + credentials = self._retrieve_credentials(client.id, code) if credentials is None: return self.json({ @@ -265,7 +263,7 @@ class GrantTokenView(HomeAssistantView): user = await hass.auth.async_get_or_create_user(credentials) refresh_token = await hass.auth.async_create_refresh_token(user, - client_id) + client) access_token = hass.auth.async_create_access_token(refresh_token) return self.json({ @@ -276,7 +274,7 @@ class GrantTokenView(HomeAssistantView): int(refresh_token.access_token_expiration.total_seconds()), }) - async def _async_handle_refresh_token(self, hass, client_id, data): + async def _async_handle_refresh_token(self, hass, client, data): """Handle authorization code request.""" token = data.get('refresh_token') @@ -287,7 +285,7 @@ class GrantTokenView(HomeAssistantView): refresh_token = await hass.auth.async_get_refresh_token(token) - if refresh_token is None or refresh_token.client_id != client_id: + if refresh_token is None or refresh_token.client_id != client.id: return self.json({ 'error': 'invalid_grant', }, status_code=400) diff --git a/tests/auth_providers/test_insecure_example.py b/tests/auth_providers/test_insecure_example.py index 3377a60c45b..cb0bab4afed 100644 --- a/tests/auth_providers/test_insecure_example.py +++ b/tests/auth_providers/test_insecure_example.py @@ -54,7 +54,7 @@ async def test_match_existing_credentials(store, provider): }, is_new=False, ) - store.credentials_for_provider = Mock(return_value=mock_coro([existing])) + provider.async_credentials = Mock(return_value=mock_coro([existing])) credentials = await provider.async_get_or_create_credentials({ 'username': 'user-test', 'password': 'password-test', diff --git a/tests/auth_providers/test_legacy_api_password.py b/tests/auth_providers/test_legacy_api_password.py index 7a8f17894aa..3a186a0454c 100644 --- a/tests/auth_providers/test_legacy_api_password.py +++ b/tests/auth_providers/test_legacy_api_password.py @@ -21,6 +21,14 @@ def provider(hass, store): }) +@pytest.fixture +def manager(hass, store, provider): + """Mock manager.""" + return auth.AuthManager(hass, store, { + (provider.type, provider.id): provider + }) + + async def test_create_new_credential(provider): """Test that we create a new credential.""" credentials = await provider.async_get_or_create_credentials({}) @@ -28,13 +36,13 @@ async def test_create_new_credential(provider): assert credentials.is_new is True -async def test_only_one_credentials(store, provider): +async def test_only_one_credentials(manager, provider): """Call create twice will return same credential.""" credentials = await provider.async_get_or_create_credentials({}) - await store.async_get_or_create_user(credentials, provider) + await manager.async_get_or_create_user(credentials) credentials2 = await provider.async_get_or_create_credentials({}) - assert credentials2.data["username"] is legacy_api_password.LEGACY_USER - assert credentials2.id is credentials.id + assert credentials2.data["username"] == legacy_api_password.LEGACY_USER + assert credentials2.id == credentials.id assert credentials2.is_new is False diff --git a/tests/common.py b/tests/common.py index 3a51cd3e059..ccb8f49ea97 100644 --- a/tests/common.py +++ b/tests/common.py @@ -312,7 +312,8 @@ class MockUser(auth.User): def __init__(self, id='mock-id', is_owner=True, is_active=True, name='Mock User'): """Initialize mock user.""" - super().__init__(id, is_owner, is_active, name) + super().__init__( + id=id, is_owner=is_owner, is_active=is_active, name=name) def add_to_hass(self, hass): """Test helper to add entry to hass.""" diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 8a1b934ab76..00e3ee88d16 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -34,5 +34,5 @@ def hass_access_token(hass): no_secret=True, )) refresh_token = hass.loop.run_until_complete( - hass.auth.async_create_refresh_token(user, client.id)) + hass.auth.async_create_refresh_token(user, client)) yield hass.auth.async_create_access_token(refresh_token) diff --git a/tests/test_auth.py b/tests/test_auth.py index 5b545223c15..8096a081679 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -184,7 +184,7 @@ async def test_saving_loading(hass, hass_storage): client = await manager.async_create_client( 'test', redirect_uris=['https://example.com']) - refresh_token = await manager.async_create_refresh_token(user, client.id) + refresh_token = await manager.async_create_refresh_token(user, client) manager.async_create_access_token(refresh_token) @@ -226,13 +226,8 @@ async def test_cannot_retrieve_expired_access_token(hass): """Test that we cannot retrieve expired access tokens.""" manager = await auth.auth_manager_from_config(hass, []) client = await manager.async_create_client('test') - user = MockUser( - id='mock-user', - is_owner=False, - is_active=False, - name='Paulus', - ).add_to_auth_manager(manager) - refresh_token = await manager.async_create_refresh_token(user, client.id) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token(user, client) assert refresh_token.user.id is user.id assert refresh_token.client_id is client.id @@ -260,23 +255,41 @@ async def test_get_or_create_client(hass): assert client2.id is client1.id -async def test_cannot_create_refresh_token_with_invalide_client_id(hass): - """Test that we cannot create refresh token with invalid client id.""" +async def test_generating_system_user(hass): + """Test that we can add a system user.""" manager = await auth.auth_manager_from_config(hass, []) - user = MockUser( - id='mock-user', - is_owner=False, - is_active=False, - name='Paulus', - ).add_to_auth_manager(manager) - with pytest.raises(ValueError): - await manager.async_create_refresh_token(user, 'bla') + user = await manager.async_create_system_user('Hass.io') + token = await manager.async_create_refresh_token(user) + assert user.system_generated + assert token is not None + assert token.client_id is None -async def test_cannot_create_refresh_token_with_invalide_user(hass): - """Test that we cannot create refresh token with invalid client id.""" +async def test_refresh_token_requires_client_for_user(hass): + """Test that we can add a system user.""" manager = await auth.auth_manager_from_config(hass, []) - client = await manager.async_create_client('test') - user = MockUser(id='invalid-user') + user = MockUser().add_to_auth_manager(manager) + assert user.system_generated is False + with pytest.raises(ValueError): - await manager.async_create_refresh_token(user, client.id) + await manager.async_create_refresh_token(user) + + client = await manager.async_get_or_create_client('Test client') + token = await manager.async_create_refresh_token(user, client) + assert token is not None + assert token.client_id == client.id + + +async def test_refresh_token_not_requires_client_for_system_user(hass): + """Test that we can add a system user.""" + manager = await auth.auth_manager_from_config(hass, []) + user = await manager.async_create_system_user('Hass.io') + assert user.system_generated is True + client = await manager.async_get_or_create_client('Test client') + + with pytest.raises(ValueError): + await manager.async_create_refresh_token(user, client) + + token = await manager.async_create_refresh_token(user) + assert token is not None + assert token.client_id is None