From 70fe463ef0b696f3b4f14086e2462a0ee13b3a6e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 13 Jul 2018 15:31:20 +0200 Subject: [PATCH] User management (#15420) * User management * Lint * Fix dict * Reuse data instance * OrderedDict all the way --- homeassistant/auth/__init__.py | 45 +++- homeassistant/auth/auth_store.py | 29 ++- homeassistant/auth/providers/__init__.py | 8 - homeassistant/auth/providers/homeassistant.py | 67 ++++- homeassistant/components/auth/__init__.py | 2 +- homeassistant/components/camera/__init__.py | 4 +- homeassistant/components/config/__init__.py | 4 + homeassistant/components/config/auth.py | 113 +++++++++ .../config/auth_provider_homeassistant.py | 120 +++++++++ homeassistant/components/http/auth.py | 5 + homeassistant/components/websocket_api.py | 21 +- homeassistant/scripts/auth.py | 53 +++- tests/auth/providers/test_homeassistant.py | 31 ++- tests/auth/test_init.py | 2 +- tests/common.py | 36 ++- tests/components/config/test_auth.py | 211 ++++++++++++++++ .../test_auth_provider_homeassistant.py | 229 ++++++++++++++++++ tests/components/conftest.py | 17 +- tests/components/http/test_auth.py | 41 ++-- tests/conftest.py | 2 +- tests/scripts/test_auth.py | 58 +++-- 21 files changed, 982 insertions(+), 116 deletions(-) create mode 100644 homeassistant/components/config/auth.py create mode 100644 homeassistant/components/config/auth_provider_homeassistant.py create mode 100644 tests/components/config/test_auth.py create mode 100644 tests/components/config/test_auth_provider_homeassistant.py diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index c5db65586b1..fb35bd05c33 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -51,7 +51,7 @@ class AuthManager: self.login_flow = data_entry_flow.FlowManager( hass, self._async_create_login_flow, self._async_finish_login_flow) - self._access_tokens = {} + self._access_tokens = OrderedDict() @property def active(self): @@ -71,9 +71,13 @@ class AuthManager: return False @property - def async_auth_providers(self): + def auth_providers(self): """Return a list of available auth providers.""" - return self._providers.values() + return list(self._providers.values()) + + async def async_get_users(self): + """Retrieve all users.""" + return await self._store.async_get_users() async def async_get_user(self, user_id): """Retrieve a user.""" @@ -87,6 +91,13 @@ class AuthManager: is_active=True, ) + async def async_create_user(self, name): + """Create a user.""" + return await self._store.async_create_user( + name=name, + is_active=True, + ) + async def async_get_or_create_user(self, credentials): """Get or create a user.""" if not credentials.is_new: @@ -98,6 +109,10 @@ class AuthManager: raise ValueError('Unable to find the user.') auth_provider = self._async_get_auth_provider(credentials) + + if auth_provider is None: + raise RuntimeError('Credential with unknown provider encountered') + info = await auth_provider.async_user_meta_for_credentials( credentials) @@ -122,8 +137,26 @@ class AuthManager: async def async_remove_user(self, user): """Remove a user.""" + tasks = [ + self.async_remove_credentials(credentials) + for credentials in user.credentials + ] + + if tasks: + await asyncio.wait(tasks) + await self._store.async_remove_user(user) + async def async_remove_credentials(self, credentials): + """Remove credentials.""" + provider = self._async_get_auth_provider(credentials) + + if (provider is not None and + hasattr(provider, 'async_will_remove_credentials')): + await provider.async_will_remove_credentials(credentials) + + await self._store.async_remove_credentials(credentials) + async def async_create_refresh_token(self, user, client_id=None): """Create a new refresh token for a user.""" if not user.is_active: @@ -168,10 +201,6 @@ class AuthManager: """Create a login flow.""" auth_provider = self._providers[handler] - if not auth_provider.initialized: - auth_provider.initialized = True - await auth_provider.async_initialize() - return await auth_provider.async_credential_flow() async def _async_finish_login_flow(self, result): @@ -188,4 +217,4 @@ class AuthManager: """Helper to get auth provider from a set of credentials.""" auth_provider_key = (credentials.auth_provider_type, credentials.auth_provider_id) - return self._providers[auth_provider_key] + return self._providers.get(auth_provider_key) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 691e561f22f..ebd61140ac1 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -1,4 +1,5 @@ """Storage for auth models.""" +from collections import OrderedDict from datetime import timedelta from homeassistant.util import dt as dt_util @@ -80,6 +81,22 @@ class AuthStore: self._users.pop(user.id) await self.async_save() + async def async_remove_credentials(self, credentials): + """Remove credentials.""" + for user in self._users.values(): + found = None + + for index, cred in enumerate(user.credentials): + if cred is credentials: + found = index + break + + if found is not None: + user.credentials.pop(found) + break + + await self.async_save() + async def async_create_refresh_token(self, user, client_id=None): """Create a new token for a user.""" refresh_token = models.RefreshToken(user=user, client_id=client_id) @@ -108,14 +125,14 @@ class AuthStore: if self._users is not None: return + users = OrderedDict() + if data is None: - self._users = {} + self._users = users return - users = { - user_dict['id']: models.User(**user_dict) - for user_dict in data['users'] - } + for user_dict in data['users']: + users[user_dict['id']] = models.User(**user_dict) for cred_dict in data['credentials']: users[cred_dict['user_id']].credentials.append(models.Credentials( @@ -126,7 +143,7 @@ class AuthStore: data=cred_dict['data'], )) - refresh_tokens = {} + refresh_tokens = OrderedDict() for rt_dict in data['refresh_tokens']: token = models.RefreshToken( diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index d6630383ff2..3769248fc05 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -77,8 +77,6 @@ class AuthProvider: DEFAULT_TITLE = 'Unnamed auth provider' - initialized = False - def __init__(self, hass, store, config): """Initialize an auth provider.""" self.hass = hass @@ -125,12 +123,6 @@ class AuthProvider: # Implement by extending class - async def async_initialize(self): - """Initialize the auth provider. - - Optional. - """ - async def async_credential_flow(self): """Return the data flow for logging in with auth provider.""" raise NotImplementedError diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index fa6878da065..17a56bc5f42 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -7,6 +7,8 @@ import hmac import voluptuous as vol from homeassistant import data_entry_flow +from homeassistant.const import CONF_ID +from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from homeassistant.auth.util import generate_secret @@ -16,8 +18,17 @@ from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS STORAGE_VERSION = 1 STORAGE_KEY = 'auth_provider.homeassistant' -CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ -}, extra=vol.PREVENT_EXTRA) + +def _disallow_id(conf): + """Disallow ID in config.""" + if CONF_ID in conf: + raise vol.Invalid( + 'ID is not allowed for the homeassistant auth provider.') + + return conf + + +CONFIG_SCHEMA = vol.All(AUTH_PROVIDER_SCHEMA, _disallow_id) class InvalidAuth(HomeAssistantError): @@ -88,8 +99,8 @@ class Data: hashed = base64.b64encode(hashed).decode() return hashed - def add_user(self, username, password): - """Add a user.""" + def add_auth(self, username, password): + """Add a new authenticated user/pass.""" if any(user['username'] == username for user in self.users): raise InvalidUser @@ -98,8 +109,22 @@ class Data: 'password': self.hash_password(password, True), }) + @callback + def async_remove_auth(self, username): + """Remove authentication.""" + index = None + for i, user in enumerate(self.users): + if user['username'] == username: + index = i + break + + if index is None: + raise InvalidUser + + self.users.pop(index) + def change_password(self, username, new_password): - """Update the password of a user. + """Update the password. Raises InvalidUser if user cannot be found. """ @@ -121,16 +146,24 @@ class HassAuthProvider(AuthProvider): DEFAULT_TITLE = 'Home Assistant Local' + data = None + + async def async_initialize(self): + """Initialize the auth provider.""" + self.data = Data(self.hass) + await self.data.async_load() + async def async_credential_flow(self): """Return a flow to login.""" return LoginFlow(self) async def async_validate_login(self, username, password): """Helper to validate a username and password.""" - data = Data(self.hass) - await data.async_load() + if self.data is None: + await self.async_initialize() + await self.hass.async_add_executor_job( - data.validate_login, username, password) + self.data.validate_login, username, password) async def async_get_or_create_credentials(self, flow_result): """Get credentials based on the flow result.""" @@ -145,6 +178,24 @@ class HassAuthProvider(AuthProvider): 'username': username }) + async def async_user_meta_for_credentials(self, credentials): + """Get extra info for this credential.""" + return { + 'name': credentials.data['username'] + } + + async def async_will_remove_credentials(self, credentials): + """When credentials get removed, also remove the auth.""" + if self.data is None: + await self.async_initialize() + + try: + self.data.async_remove_auth(credentials.data['username']) + await self.data.async_save() + except InvalidUser: + # Can happen if somehow we didn't clean up a credential + pass + class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 3e236876d6a..1ead4cacdf0 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -152,7 +152,7 @@ class AuthProvidersView(HomeAssistantView): 'name': provider.name, 'id': provider.id, 'type': provider.type, - } for provider in request.app['hass'].auth.async_auth_providers]) + } for provider in request.app['hass'].auth.auth_providers]) class LoginFlowIndexView(FlowManagerIndexView): diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 14550dab899..22354b51956 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -66,8 +66,8 @@ CAMERA_SERVICE_SNAPSHOT = CAMERA_SERVICE_SCHEMA.extend({ WS_TYPE_CAMERA_THUMBNAIL = 'camera_thumbnail' SCHEMA_WS_CAMERA_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ - 'type': WS_TYPE_CAMERA_THUMBNAIL, - 'entity_id': cv.entity_id + vol.Required('type'): WS_TYPE_CAMERA_THUMBNAIL, + vol.Required('entity_id'): cv.entity_id }) diff --git a/homeassistant/components/config/__init__.py b/homeassistant/components/config/__init__.py index b907d4b4217..581d8fc3f7b 100644 --- a/homeassistant/components/config/__init__.py +++ b/homeassistant/components/config/__init__.py @@ -49,6 +49,10 @@ async def async_setup(hass, config): tasks = [setup_panel(panel_name) for panel_name in SECTIONS] + if hass.auth.active: + tasks.append(setup_panel('auth')) + tasks.append(setup_panel('auth_provider_homeassistant')) + for panel_name in ON_DEMAND: if panel_name in hass.config.components: tasks.append(setup_panel(panel_name)) diff --git a/homeassistant/components/config/auth.py b/homeassistant/components/config/auth.py new file mode 100644 index 00000000000..6f00b03dedb --- /dev/null +++ b/homeassistant/components/config/auth.py @@ -0,0 +1,113 @@ +"""Offer API to configure Home Assistant auth.""" +import voluptuous as vol + +from homeassistant.core import callback +from homeassistant.components import websocket_api + + +WS_TYPE_LIST = 'config/auth/list' +SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_LIST, +}) + +WS_TYPE_DELETE = 'config/auth/delete' +SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_DELETE, + vol.Required('user_id'): str, +}) + +WS_TYPE_CREATE = 'config/auth/create' +SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_CREATE, + vol.Required('name'): str, +}) + + +async def async_setup(hass): + """Enable the Home Assistant views.""" + hass.components.websocket_api.async_register_command( + WS_TYPE_LIST, websocket_list, + SCHEMA_WS_LIST + ) + hass.components.websocket_api.async_register_command( + WS_TYPE_DELETE, websocket_delete, + SCHEMA_WS_DELETE + ) + hass.components.websocket_api.async_register_command( + WS_TYPE_CREATE, websocket_create, + SCHEMA_WS_CREATE + ) + return True + + +@callback +@websocket_api.require_owner +def websocket_list(hass, connection, msg): + """Return a list of users.""" + async def send_users(): + """Send users.""" + result = [_user_info(u) for u in await hass.auth.async_get_users()] + + connection.send_message_outside( + websocket_api.result_message(msg['id'], result)) + + hass.async_add_job(send_users()) + + +@callback +@websocket_api.require_owner +def websocket_delete(hass, connection, msg): + """Delete a user.""" + async def delete_user(): + """Delete user.""" + if msg['user_id'] == connection.request.get('hass_user').id: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'no_delete_self', + 'Unable to delete your own account')) + return + + user = await hass.auth.async_get_user(msg['user_id']) + + if not user: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'not_found', 'User not found')) + return + + await hass.auth.async_remove_user(user) + + connection.send_message_outside( + websocket_api.result_message(msg['id'])) + + hass.async_add_job(delete_user()) + + +@callback +@websocket_api.require_owner +def websocket_create(hass, connection, msg): + """Create a user.""" + async def create_user(): + """Create a user.""" + user = await hass.auth.async_create_user(msg['name']) + + connection.send_message_outside( + websocket_api.result_message(msg['id'], { + 'user': _user_info(user) + })) + + hass.async_add_job(create_user()) + + +def _user_info(user): + """Format a user.""" + return { + 'id': user.id, + 'name': user.name, + 'is_owner': user.is_owner, + 'is_active': user.is_active, + 'system_generated': user.system_generated, + 'credentials': [ + { + 'type': c.auth_provider_type, + } for c in user.credentials + ] + } diff --git a/homeassistant/components/config/auth_provider_homeassistant.py b/homeassistant/components/config/auth_provider_homeassistant.py new file mode 100644 index 00000000000..fca03ad8fa9 --- /dev/null +++ b/homeassistant/components/config/auth_provider_homeassistant.py @@ -0,0 +1,120 @@ +"""Offer API to configure the Home Assistant auth provider.""" +import voluptuous as vol + +from homeassistant.auth.providers import homeassistant as auth_ha +from homeassistant.core import callback +from homeassistant.components import websocket_api + + +WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create' +SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_CREATE, + vol.Required('user_id'): str, + vol.Required('username'): str, + vol.Required('password'): str, +}) + +WS_TYPE_DELETE = 'config/auth_provider/homeassistant/delete' +SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_DELETE, + vol.Required('username'): str, +}) + + +async def async_setup(hass): + """Enable the Home Assistant views.""" + hass.components.websocket_api.async_register_command( + WS_TYPE_CREATE, websocket_create, + SCHEMA_WS_CREATE + ) + hass.components.websocket_api.async_register_command( + WS_TYPE_DELETE, websocket_delete, + SCHEMA_WS_DELETE + ) + return True + + +def _get_provider(hass): + """Get homeassistant auth provider.""" + for prv in hass.auth.auth_providers: + if prv.type == 'homeassistant': + return prv + + raise RuntimeError('Provider not found') + + +@callback +@websocket_api.require_owner +def websocket_create(hass, connection, msg): + """Create credentials and attach to a user.""" + async def create_creds(): + """Create credentials.""" + provider = _get_provider(hass) + await provider.async_initialize() + + user = await hass.auth.async_get_user(msg['user_id']) + + if user is None: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'not_found', 'User not found')) + return + + if user.system_generated: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'system_generated', + 'Cannot add credentials to a system generated user.')) + return + + try: + await hass.async_add_executor_job( + provider.data.add_auth, msg['username'], msg['password']) + except auth_ha.InvalidUser: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'username_exists', 'Username already exists')) + return + + credentials = await provider.async_get_or_create_credentials({ + 'username': msg['username'] + }) + await hass.auth.async_link_user(user, credentials) + + await provider.data.async_save() + connection.to_write.put_nowait(websocket_api.result_message(msg['id'])) + + hass.async_add_job(create_creds()) + + +@callback +@websocket_api.require_owner +def websocket_delete(hass, connection, msg): + """Delete username and related credential.""" + async def delete_creds(): + """Delete user credentials.""" + provider = _get_provider(hass) + await provider.async_initialize() + + credentials = await provider.async_get_or_create_credentials({ + 'username': msg['username'] + }) + + # if not new, an existing credential exists. + # Removing the credential will also remove the auth. + if not credentials.is_new: + await hass.auth.async_remove_credentials(credentials) + + connection.to_write.put_nowait( + websocket_api.result_message(msg['id'])) + return + + try: + provider.data.async_remove_auth(msg['username']) + await provider.data.async_save() + except auth_ha.InvalidUser: + connection.to_write.put_nowait(websocket_api.error_message( + msg['id'], 'auth_not_found', 'Given username was not found.')) + return + + connection.to_write.put_nowait( + websocket_api.result_message(msg['id'])) + + hass.async_add_job(delete_creds()) diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 2cc62dce38e..46d77214160 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -106,6 +106,11 @@ async def async_validate_auth_header(request, api_password=None): if access_token is None: return False + user = access_token.refresh_token.user + + if not user.is_active: + return False + request['hass_user'] = access_token.refresh_token.user return True diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index c26f68a2c29..6cd16909041 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -7,7 +7,7 @@ https://home-assistant.io/developers/websocket_api/ import asyncio from concurrent import futures from contextlib import suppress -from functools import partial +from functools import partial, wraps import json import logging @@ -196,6 +196,23 @@ def async_register_command(hass, command, handler, schema): handlers[command] = (handler, schema) +def require_owner(func): + """Websocket decorator to require user to be an owner.""" + @wraps(func) + def with_owner(hass, connection, msg): + """Check owner and call function.""" + user = connection.request.get('hass_user') + + if user is None or not user.is_owner: + connection.to_write.put_nowait(error_message( + msg['id'], 'unauthorized', 'This command is for owners only.')) + return + + func(hass, connection, msg) + + return with_owner + + async def async_setup(hass, config): """Initialize the websocket API.""" hass.http.register_view(WebsocketAPIView) @@ -325,6 +342,8 @@ class ActiveConnection: token = self.hass.auth.async_get_access_token( msg['access_token']) authenticated = token is not None + if authenticated: + request['hass_user'] = token.refresh_token.user elif ((not self.hass.auth.active or self.hass.auth.support_legacy) and diff --git a/homeassistant/scripts/auth.py b/homeassistant/scripts/auth.py index aa39e9f66df..fea523c4117 100644 --- a/homeassistant/scripts/auth.py +++ b/homeassistant/scripts/auth.py @@ -1,8 +1,10 @@ """Script to manage users for the Home Assistant auth provider.""" import argparse import asyncio +import logging import os +from homeassistant.auth import auth_manager_from_config from homeassistant.core import HomeAssistant from homeassistant.config import get_default_config_dir from homeassistant.auth.providers import homeassistant as hass_auth @@ -42,16 +44,28 @@ def run(args): args = parser.parse_args(args) loop = asyncio.get_event_loop() hass = HomeAssistant(loop=loop) + loop.run_until_complete(run_command(hass, args)) + + # Triggers save on used storage helpers with delay (core auth) + logging.getLogger('homeassistant.core').setLevel(logging.WARNING) + loop.run_until_complete(hass.async_stop()) + + +async def run_command(hass, args): + """Run the command.""" hass.config.config_dir = os.path.join(os.getcwd(), args.config) - data = hass_auth.Data(hass) - loop.run_until_complete(data.async_load()) - loop.run_until_complete(args.func(data, args)) + hass.auth = await auth_manager_from_config(hass, [{ + 'type': 'homeassistant', + }]) + provider = hass.auth.auth_providers[0] + await provider.async_initialize() + await args.func(hass, provider, args) -async def list_users(data, args): +async def list_users(hass, provider, args): """List the users.""" count = 0 - for user in data.users: + for user in provider.data.users: count += 1 print(user['username']) @@ -59,27 +73,40 @@ async def list_users(data, args): print("Total users:", count) -async def add_user(data, args): +async def add_user(hass, provider, args): """Create a user.""" - data.add_user(args.username, args.password) - await data.async_save() + try: + provider.data.add_auth(args.username, args.password) + except hass_auth.InvalidUser: + print("Username already exists!") + return + + credentials = await provider.async_get_or_create_credentials({ + 'username': args.username + }) + + user = await hass.auth.async_create_user(args.username) + await hass.auth.async_link_user(user, credentials) + + # Save username/password + await provider.data.async_save() print("User created") -async def validate_login(data, args): +async def validate_login(hass, provider, args): """Validate a login.""" try: - data.validate_login(args.username, args.password) + provider.data.validate_login(args.username, args.password) print("Auth valid") except hass_auth.InvalidAuth: print("Auth invalid") -async def change_password(data, args): +async def change_password(hass, provider, args): """Change password.""" try: - data.change_password(args.username, args.new_password) - await data.async_save() + provider.data.change_password(args.username, args.new_password) + await provider.data.async_save() print("Password changed") except hass_auth.InvalidUser: print("User not found") diff --git a/tests/auth/providers/test_homeassistant.py b/tests/auth/providers/test_homeassistant.py index 98701ba2857..08fb63a3c72 100644 --- a/tests/auth/providers/test_homeassistant.py +++ b/tests/auth/providers/test_homeassistant.py @@ -1,8 +1,11 @@ """Test the Home Assistant local auth provider.""" +from unittest.mock import Mock + import pytest from homeassistant import data_entry_flow -from homeassistant.auth.providers import homeassistant as hass_auth +from homeassistant.auth.providers import ( + auth_provider_from_config, homeassistant as hass_auth) @pytest.fixture @@ -15,15 +18,15 @@ def data(hass): async def test_adding_user(data, hass): """Test adding a user.""" - data.add_user('test-user', 'test-pass') + data.add_auth('test-user', 'test-pass') data.validate_login('test-user', 'test-pass') async def test_adding_user_duplicate_username(data, hass): """Test adding a user.""" - data.add_user('test-user', 'test-pass') + data.add_auth('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidUser): - data.add_user('test-user', 'other-pass') + data.add_auth('test-user', 'other-pass') async def test_validating_password_invalid_user(data, hass): @@ -34,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass): async def test_validating_password_invalid_password(data, hass): """Test validating an invalid user.""" - data.add_user('test-user', 'test-pass') + data.add_auth('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidAuth): data.validate_login('test-user', 'invalid-pass') @@ -43,7 +46,7 @@ async def test_validating_password_invalid_password(data, hass): async def test_changing_password(data, hass): """Test adding a user.""" user = 'test-user' - data.add_user(user, 'test-pass') + data.add_auth(user, 'test-pass') data.change_password(user, 'new-pass') with pytest.raises(hass_auth.InvalidAuth): @@ -60,7 +63,7 @@ async def test_changing_password_raises_invalid_user(data, hass): async def test_login_flow_validates(data, hass): """Test login flow.""" - data.add_user('test-user', 'test-pass') + data.add_auth('test-user', 'test-pass') await data.async_save() provider = hass_auth.HassAuthProvider(hass, None, {}) @@ -91,11 +94,21 @@ async def test_login_flow_validates(data, hass): async def test_saving_loading(data, hass): """Test saving and loading JSON.""" - data.add_user('test-user', 'test-pass') - data.add_user('second-user', 'second-pass') + data.add_auth('test-user', 'test-pass') + data.add_auth('second-user', 'second-pass') await data.async_save() data = hass_auth.Data(hass) await data.async_load() data.validate_login('test-user', 'test-pass') data.validate_login('second-user', 'second-pass') + + +async def test_not_allow_set_id(): + """Test we are not allowed to set an ID in config.""" + hass = Mock() + provider = await auth_provider_from_config(hass, None, { + 'type': 'homeassistant', + 'id': 'invalid', + }) + assert provider is None diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index 805369a6da8..f7187fd49fd 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -46,7 +46,7 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass): 'name': provider.name, 'id': provider.id, 'type': provider.type, - } for provider in manager.async_auth_providers] + } for provider in manager.auth_providers] assert providers == [{ 'name': 'Test Name', 'type': 'insecure_example', diff --git a/tests/common.py b/tests/common.py index b3da5e0d098..b03d473e6f3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,5 +1,6 @@ """Test the helper method for writing tests.""" import asyncio +from collections import OrderedDict from datetime import timedelta import functools as ft import json @@ -12,7 +13,8 @@ import threading from contextlib import contextmanager from homeassistant import auth, core as ha, data_entry_flow, config_entries -from homeassistant.auth import models as auth_models, auth_store +from homeassistant.auth import ( + models as auth_models, auth_store, providers as auth_providers) from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers import ( @@ -312,11 +314,12 @@ def mock_registry(hass, mock_entries=None): class MockUser(auth_models.User): """Mock a user in Home Assistant.""" - def __init__(self, id='mock-id', is_owner=True, is_active=True, - name='Mock User'): + def __init__(self, id='mock-id', is_owner=False, is_active=True, + name='Mock User', system_generated=False): """Initialize mock user.""" super().__init__( - id=id, is_owner=is_owner, is_active=is_active, name=name) + id=id, is_owner=is_owner, is_active=is_active, name=name, + system_generated=system_generated) def add_to_hass(self, hass): """Test helper to add entry to hass.""" @@ -329,12 +332,27 @@ class MockUser(auth_models.User): return self +async def register_auth_provider(hass, config): + """Helper to register an auth provider.""" + provider = await auth_providers.auth_provider_from_config( + hass, hass.auth._store, config) + assert provider is not None, 'Invalid config specified' + key = (provider.type, provider.id) + providers = hass.auth._providers + + if key in providers: + raise ValueError('Provider already registered') + + providers[key] = provider + return provider + + @ha.callback def ensure_auth_manager_loaded(auth_mgr): """Ensure an auth manager is considered loaded.""" store = auth_mgr._store if store._users is None: - store._users = {} + store._users = OrderedDict() class MockModule(object): @@ -731,7 +749,13 @@ def mock_storage(data=None): if store.key not in data: return None - store._data = data.get(store.key) + mock_data = data.get(store.key) + + if 'data' not in mock_data or 'version' not in mock_data: + _LOGGER.error('Mock data needs "version" and "data"') + raise ValueError('Mock data needs "version" and "data"') + + store._data = mock_data # Route through original load so that we trigger migration loaded = await orig_load(store) diff --git a/tests/components/config/test_auth.py b/tests/components/config/test_auth.py new file mode 100644 index 00000000000..fe8f351955f --- /dev/null +++ b/tests/components/config/test_auth.py @@ -0,0 +1,211 @@ +"""Test config entries API.""" +from unittest.mock import PropertyMock, patch + +import pytest + +from homeassistant.auth import models as auth_models +from homeassistant.components.config import auth as auth_config + +from tests.common import MockUser, CLIENT_ID + + +@pytest.fixture(autouse=True) +def auth_active(hass): + """Mock that auth is active.""" + with patch('homeassistant.auth.AuthManager.active', + PropertyMock(return_value=True)): + yield + + +@pytest.fixture(autouse=True) +def setup_config(hass, aiohttp_client): + """Fixture that sets up the auth provider homeassistant module.""" + hass.loop.run_until_complete(auth_config.async_setup(hass)) + + +async def test_list_requires_owner(hass, hass_ws_client, hass_access_token): + """Test get users requires auth.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_LIST, + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' + + +async def test_list(hass, hass_ws_client): + """Test get users.""" + owner = MockUser( + id='abc', + name='Test Owner', + is_owner=True, + ).add_to_hass(hass) + + owner.credentials.append(auth_models.Credentials( + auth_provider_type='homeassistant', + auth_provider_id=None, + data={}, + )) + + system = MockUser( + id='efg', + name='Test Hass.io', + system_generated=True + ).add_to_hass(hass) + + inactive = MockUser( + id='hij', + name='Inactive User', + is_active=False, + ).add_to_hass(hass) + + refresh_token = await hass.auth.async_create_refresh_token( + owner, CLIENT_ID) + access_token = hass.auth.async_create_access_token(refresh_token) + + client = await hass_ws_client(hass, access_token) + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_LIST, + }) + + result = await client.receive_json() + assert result['success'], result + data = result['result'] + assert len(data) == 3 + assert data[0] == { + 'id': owner.id, + 'name': 'Test Owner', + 'is_owner': True, + 'is_active': True, + 'system_generated': False, + 'credentials': [{'type': 'homeassistant'}] + } + assert data[1] == { + 'id': system.id, + 'name': 'Test Hass.io', + 'is_owner': False, + 'is_active': True, + 'system_generated': True, + 'credentials': [], + } + assert data[2] == { + 'id': inactive.id, + 'name': 'Inactive User', + 'is_owner': False, + 'is_active': False, + 'system_generated': False, + 'credentials': [], + } + + +async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token): + """Test delete command requires an owner.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_DELETE, + 'user_id': 'abcd', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' + + +async def test_delete_unable_self_account(hass, hass_ws_client, + hass_access_token): + """Test we cannot delete our own account.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_DELETE, + 'user_id': hass_access_token.refresh_token.user.id, + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' + + +async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token): + """Test we cannot delete an unknown user.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_DELETE, + 'user_id': 'abcd', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'not_found' + + +async def test_delete(hass, hass_ws_client, hass_access_token): + """Test delete command works.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + test_user = MockUser( + id='efg', + ).add_to_hass(hass) + + assert len(await hass.auth.async_get_users()) == 2 + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_DELETE, + 'user_id': test_user.id, + }) + + result = await client.receive_json() + assert result['success'], result + assert len(await hass.auth.async_get_users()) == 1 + + +async def test_create(hass, hass_ws_client, hass_access_token): + """Test create command works.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + assert len(await hass.auth.async_get_users()) == 1 + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_CREATE, + 'name': 'Paulus', + }) + + result = await client.receive_json() + assert result['success'], result + assert len(await hass.auth.async_get_users()) == 2 + data_user = result['result']['user'] + user = await hass.auth.async_get_user(data_user['id']) + assert user is not None + assert user.name == data_user['name'] + assert user.is_active + assert not user.is_owner + assert not user.system_generated + + +async def test_create_requires_owner(hass, hass_ws_client, hass_access_token): + """Test create command requires an owner.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_config.WS_TYPE_CREATE, + 'name': 'YO', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' diff --git a/tests/components/config/test_auth_provider_homeassistant.py b/tests/components/config/test_auth_provider_homeassistant.py new file mode 100644 index 00000000000..fa4ab612bb1 --- /dev/null +++ b/tests/components/config/test_auth_provider_homeassistant.py @@ -0,0 +1,229 @@ +"""Test config entries API.""" +import pytest + +from homeassistant.auth.providers import homeassistant as prov_ha +from homeassistant.components.config import ( + auth_provider_homeassistant as auth_ha) + +from tests.common import MockUser, register_auth_provider + + +@pytest.fixture(autouse=True) +def setup_config(hass, aiohttp_client): + """Fixture that sets up the auth provider homeassistant module.""" + hass.loop.run_until_complete(register_auth_provider(hass, { + 'type': 'homeassistant' + })) + hass.loop.run_until_complete(auth_ha.async_setup(hass)) + + +async def test_create_auth_system_generated_user(hass, hass_access_token, + hass_ws_client): + """Test we can't add auth to system generated users.""" + system_user = MockUser(system_generated=True).add_to_hass(hass) + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_CREATE, + 'user_id': system_user.id, + 'username': 'test-user', + 'password': 'test-pass', + }) + + result = await client.receive_json() + + assert not result['success'], result + assert result['error']['code'] == 'system_generated' + + +async def test_create_auth_user_already_credentials(): + """Test we can't create auth for user with pre-existing credentials.""" + # assert False + + +async def test_create_auth_unknown_user(hass_ws_client, hass, + hass_access_token): + """Test create pointing at unknown user.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_CREATE, + 'user_id': 'test-id', + 'username': 'test-user', + 'password': 'test-pass', + }) + + result = await client.receive_json() + + assert not result['success'], result + assert result['error']['code'] == 'not_found' + + +async def test_create_auth_requires_owner(hass, hass_ws_client, + hass_access_token): + """Test create requires owner to call API.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_CREATE, + 'user_id': 'test-id', + 'username': 'test-user', + 'password': 'test-pass', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' + + +async def test_create_auth(hass, hass_ws_client, hass_access_token, + hass_storage): + """Test create auth command works.""" + client = await hass_ws_client(hass, hass_access_token) + user = MockUser().add_to_hass(hass) + hass_access_token.refresh_token.user.is_owner = True + + assert len(user.credentials) == 0 + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_CREATE, + 'user_id': user.id, + 'username': 'test-user', + 'password': 'test-pass', + }) + + result = await client.receive_json() + assert result['success'], result + assert len(user.credentials) == 1 + creds = user.credentials[0] + assert creds.auth_provider_type == 'homeassistant' + assert creds.auth_provider_id is None + assert creds.data == { + 'username': 'test-user' + } + assert prov_ha.STORAGE_KEY in hass_storage + entry = hass_storage[prov_ha.STORAGE_KEY]['data']['users'][0] + assert entry['username'] == 'test-user' + + +async def test_create_auth_duplicate_username(hass, hass_ws_client, + hass_access_token, hass_storage): + """Test we can't create auth with a duplicate username.""" + client = await hass_ws_client(hass, hass_access_token) + user = MockUser().add_to_hass(hass) + hass_access_token.refresh_token.user.is_owner = True + + hass_storage[prov_ha.STORAGE_KEY] = { + 'version': 1, + 'data': { + 'users': [{ + 'username': 'test-user' + }] + } + } + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_CREATE, + 'user_id': user.id, + 'username': 'test-user', + 'password': 'test-pass', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'username_exists' + + +async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage, + hass_access_token): + """Test deleting an auth without being connected to a user.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + hass_storage[prov_ha.STORAGE_KEY] = { + 'version': 1, + 'data': { + 'users': [{ + 'username': 'test-user' + }] + } + } + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_DELETE, + 'username': 'test-user', + }) + + result = await client.receive_json() + assert result['success'], result + assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0 + + +async def test_delete_removes_credential(hass, hass_ws_client, + hass_access_token, hass_storage): + """Test deleting auth that is connected to a user.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + user = MockUser().add_to_hass(hass) + user.credentials.append( + await hass.auth.auth_providers[0].async_get_or_create_credentials({ + 'username': 'test-user'})) + + hass_storage[prov_ha.STORAGE_KEY] = { + 'version': 1, + 'data': { + 'users': [{ + 'username': 'test-user' + }] + } + } + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_DELETE, + 'username': 'test-user', + }) + + result = await client.receive_json() + assert result['success'], result + assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0 + + +async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token): + """Test delete requires owner.""" + client = await hass_ws_client(hass, hass_access_token) + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_DELETE, + 'username': 'test-user', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'unauthorized' + + +async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token): + """Test trying to delete an unknown auth username.""" + client = await hass_ws_client(hass, hass_access_token) + hass_access_token.refresh_token.user.is_owner = True + + await client.send_json({ + 'id': 5, + 'type': auth_ha.WS_TYPE_DELETE, + 'username': 'test-user', + }) + + result = await client.receive_json() + assert not result['success'], result + assert result['error']['code'] == 'auth_not_found' diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 843866cbfbd..5f6a17a4101 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -2,6 +2,7 @@ import pytest from homeassistant.setup import async_setup_component +from homeassistant.components import websocket_api from tests.common import MockUser, CLIENT_ID @@ -9,13 +10,27 @@ from tests.common import MockUser, CLIENT_ID @pytest.fixture def hass_ws_client(aiohttp_client): """Websocket client fixture connected to websocket server.""" - async def create_client(hass): + async def create_client(hass, access_token=None): """Create a websocket client.""" wapi = hass.components.websocket_api assert await async_setup_component(hass, 'websocket_api') client = await aiohttp_client(hass.http.app) websocket = await client.ws_connect(wapi.URL) + auth_resp = await websocket.receive_json() + + if auth_resp['type'] == wapi.TYPE_AUTH_OK: + assert access_token is None, \ + 'Access token given but no auth required' + return websocket + + assert access_token is not None, 'Access token required for fixture' + + await websocket.send_json({ + 'type': websocket_api.TYPE_AUTH, + 'access_token': access_token.token + }) + auth_ok = await websocket.receive_json() assert auth_ok['type'] == wapi.TYPE_AUTH_OK diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index 19785958422..31cba79a6c8 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -1,13 +1,12 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access from ipaddress import ip_network -from unittest.mock import patch, Mock +from unittest.mock import patch import pytest from aiohttp import BasicAuth, web from aiohttp.web_exceptions import HTTPUnauthorized -from homeassistant.auth.models import AccessToken, RefreshToken from homeassistant.components.http.auth import setup_auth from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.real_ip import setup_real_ip @@ -16,8 +15,6 @@ from homeassistant.setup import async_setup_component from . import mock_real_ip -ACCESS_TOKEN = 'tk.1234' - API_PASSWORD = 'test1234' # Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases @@ -39,33 +36,21 @@ async def mock_handler(request): return web.Response(status=200) -def mock_async_get_access_token(token): - """Return if token is valid.""" - if token == ACCESS_TOKEN: - return Mock(spec=AccessToken, - token=ACCESS_TOKEN, - refresh_token=Mock(spec=RefreshToken)) - else: - return None - - @pytest.fixture -def app(): +def app(hass): """Fixture to setup a web.Application.""" app = web.Application() - mock_auth = Mock(async_get_access_token=mock_async_get_access_token) - app['hass'] = Mock(auth=mock_auth) + app['hass'] = hass app.router.add_get('/', mock_handler) setup_real_ip(app, False, []) return app @pytest.fixture -def app2(): +def app2(hass): """Fixture to setup a web.Application without real_ip middleware.""" app = web.Application() - mock_auth = Mock(async_get_access_token=mock_async_get_access_token) - app['hass'] = Mock(auth=mock_auth) + app['hass'] = hass app.router.add_get('/', mock_handler) return app @@ -171,33 +156,35 @@ async def test_access_with_trusted_ip(app2, aiohttp_client): async def test_auth_active_access_with_access_token_in_header( - app, aiohttp_client): + app, aiohttp_client, hass_access_token): """Test access with access token in header.""" + token = hass_access_token.token setup_auth(app, [], True, api_password=None) client = await aiohttp_client(app) req = await client.get( - '/', headers={'Authorization': 'Bearer {}'.format(ACCESS_TOKEN)}) + '/', headers={'Authorization': 'Bearer {}'.format(token)}) assert req.status == 200 req = await client.get( - '/', headers={'AUTHORIZATION': 'Bearer {}'.format(ACCESS_TOKEN)}) + '/', headers={'AUTHORIZATION': 'Bearer {}'.format(token)}) assert req.status == 200 req = await client.get( - '/', headers={'authorization': 'Bearer {}'.format(ACCESS_TOKEN)}) + '/', headers={'authorization': 'Bearer {}'.format(token)}) assert req.status == 200 req = await client.get( - '/', headers={'Authorization': ACCESS_TOKEN}) + '/', headers={'Authorization': token}) assert req.status == 401 req = await client.get( - '/', headers={'Authorization': 'BEARER {}'.format(ACCESS_TOKEN)}) + '/', headers={'Authorization': 'BEARER {}'.format(token)}) assert req.status == 401 + hass_access_token.refresh_token.user.is_active = False req = await client.get( - '/', headers={'Authorization': 'Bearer wrong-pass'}) + '/', headers={'Authorization': 'Bearer {}'.format(token)}) assert req.status == 401 diff --git a/tests/conftest.py b/tests/conftest.py index 0a350b62fc1..28c47948666 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ if os.environ.get('UVLOOP') == '1': import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) diff --git a/tests/scripts/test_auth.py b/tests/scripts/test_auth.py index cd0524eb032..1320be299b8 100644 --- a/tests/scripts/test_auth.py +++ b/tests/scripts/test_auth.py @@ -6,21 +6,26 @@ import pytest from homeassistant.scripts import auth as script_auth from homeassistant.auth.providers import homeassistant as hass_auth +from tests.common import register_auth_provider + @pytest.fixture -def data(hass): - """Create a loaded data class.""" - data = hass_auth.Data(hass) - hass.loop.run_until_complete(data.async_load()) - return data +def provider(hass): + """Home Assistant auth provider.""" + provider = hass.loop.run_until_complete(register_auth_provider(hass, { + 'type': 'homeassistant', + })) + hass.loop.run_until_complete(provider.async_initialize()) + return provider -async def test_list_user(data, capsys): +async def test_list_user(hass, provider, capsys): """Test we can list users.""" - data.add_user('test-user', 'test-pass') - data.add_user('second-user', 'second-pass') + data = provider.data + data.add_auth('test-user', 'test-pass') + data.add_auth('second-user', 'second-pass') - await script_auth.list_users(data, None) + await script_auth.list_users(hass, provider, None) captured = capsys.readouterr() @@ -33,10 +38,11 @@ async def test_list_user(data, capsys): ]) -async def test_add_user(data, capsys, hass_storage): +async def test_add_user(hass, provider, capsys, hass_storage): """Test we can add a user.""" + data = provider.data await script_auth.add_user( - data, Mock(username='paulus', password='test-pass')) + hass, provider, Mock(username='paulus', password='test-pass')) assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 @@ -47,32 +53,34 @@ async def test_add_user(data, capsys, hass_storage): data.validate_login('paulus', 'test-pass') -async def test_validate_login(data, capsys): +async def test_validate_login(hass, provider, capsys): """Test we can validate a user login.""" - data.add_user('test-user', 'test-pass') + data = provider.data + data.add_auth('test-user', 'test-pass') await script_auth.validate_login( - data, Mock(username='test-user', password='test-pass')) + hass, provider, Mock(username='test-user', password='test-pass')) captured = capsys.readouterr() assert captured.out == 'Auth valid\n' await script_auth.validate_login( - data, Mock(username='test-user', password='invalid-pass')) + hass, provider, Mock(username='test-user', password='invalid-pass')) captured = capsys.readouterr() assert captured.out == 'Auth invalid\n' await script_auth.validate_login( - data, Mock(username='invalid-user', password='test-pass')) + hass, provider, Mock(username='invalid-user', password='test-pass')) captured = capsys.readouterr() assert captured.out == 'Auth invalid\n' -async def test_change_password(data, capsys, hass_storage): +async def test_change_password(hass, provider, capsys, hass_storage): """Test we can change a password.""" - data.add_user('test-user', 'test-pass') + data = provider.data + data.add_auth('test-user', 'test-pass') await script_auth.change_password( - data, Mock(username='test-user', new_password='new-pass')) + hass, provider, Mock(username='test-user', new_password='new-pass')) assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 captured = capsys.readouterr() @@ -82,12 +90,14 @@ async def test_change_password(data, capsys, hass_storage): data.validate_login('test-user', 'test-pass') -async def test_change_password_invalid_user(data, capsys, hass_storage): +async def test_change_password_invalid_user(hass, provider, capsys, + hass_storage): """Test changing password of non-existing user.""" - data.add_user('test-user', 'test-pass') + data = provider.data + data.add_auth('test-user', 'test-pass') await script_auth.change_password( - data, Mock(username='invalid-user', new_password='new-pass')) + hass, provider, Mock(username='invalid-user', new_password='new-pass')) assert hass_auth.STORAGE_KEY not in hass_storage captured = capsys.readouterr() @@ -101,11 +111,11 @@ def test_parsing_args(loop): """Test we parse args correctly.""" called = False - async def mock_func(data, args2): + async def mock_func(hass, provider, args2): """Mock function to be called.""" nonlocal called called = True - assert data.hass.config.config_dir == '/somewhere/config' + assert provider.hass.config.config_dir == '/somewhere/config' assert args2 is args args = Mock(config='/somewhere/config', func=mock_func)