From 26590e244ced1b67792bac8fdd5fc81ac446dba1 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 29 Jun 2018 00:02:45 -0400 Subject: [PATCH] Migrate home assistant auth provider to use storage helper (#15200) --- homeassistant/auth_providers/homeassistant.py | 39 ++++---- homeassistant/scripts/auth.py | 25 +++-- tests/auth_providers/test_homeassistant.py | 93 +++++++------------ tests/scripts/test_auth.py | 72 ++++++++------ 4 files changed, 113 insertions(+), 116 deletions(-) diff --git a/homeassistant/auth_providers/homeassistant.py b/homeassistant/auth_providers/homeassistant.py index c2db193ce1a..c4d2021f6ce 100644 --- a/homeassistant/auth_providers/homeassistant.py +++ b/homeassistant/auth_providers/homeassistant.py @@ -8,10 +8,10 @@ import voluptuous as vol from homeassistant import auth, data_entry_flow from homeassistant.exceptions import HomeAssistantError -from homeassistant.util import json -PATH_DATA = '.users.json' +STORAGE_VERSION = 1 +STORAGE_KEY = 'auth_provider.homeassistant' CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({ }, extra=vol.PREVENT_EXTRA) @@ -31,14 +31,22 @@ class InvalidUser(HomeAssistantError): class Data: """Hold the user data.""" - def __init__(self, path, data): + def __init__(self, hass): """Initialize the user data store.""" - self.path = path + self.hass = hass + self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + self._data = None + + async def async_load(self): + """Load stored data.""" + data = await self._store.async_load() + if data is None: data = { 'salt': auth.generate_secret(), 'users': [] } + self._data = data @property @@ -99,14 +107,9 @@ class Data: else: raise InvalidUser - def save(self): + async def async_save(self): """Save data.""" - json.save_json(self.path, self._data) - - -def load_data(path): - """Load auth data.""" - return Data(path, json.load_json(path, None)) + await self._store.async_save(self._data) @auth.AUTH_PROVIDERS.register('homeassistant') @@ -121,12 +124,10 @@ class HassAuthProvider(auth.AuthProvider): async def async_validate_login(self, username, password): """Helper to validate a username and password.""" - def validate(): - """Validate creds.""" - data = self._auth_data() - data.validate_login(username, password) - - await self.hass.async_add_job(validate) + data = Data(self.hass) + await data.async_load() + await self.hass.async_add_executor_job( + data.validate_login, username, password) async def async_get_or_create_credentials(self, flow_result): """Get credentials based on the flow result.""" @@ -141,10 +142,6 @@ class HassAuthProvider(auth.AuthProvider): 'username': username }) - def _auth_data(self): - """Return the auth provider data.""" - return load_data(self.hass.config.path(PATH_DATA)) - class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" diff --git a/homeassistant/scripts/auth.py b/homeassistant/scripts/auth.py index b4f1ddd2f11..dacdc7b18e2 100644 --- a/homeassistant/scripts/auth.py +++ b/homeassistant/scripts/auth.py @@ -1,7 +1,9 @@ """Script to manage users for the Home Assistant auth provider.""" import argparse +import asyncio import os +from homeassistant.core import HomeAssistant from homeassistant.config import get_default_config_dir from homeassistant.auth_providers import homeassistant as hass_auth @@ -17,7 +19,8 @@ def run(args): default=get_default_config_dir(), help="Directory that contains the Home Assistant configuration") - subparsers = parser.add_subparsers() + subparsers = parser.add_subparsers(dest='func') + subparsers.required = True parser_list = subparsers.add_parser('list') parser_list.set_defaults(func=list_users) @@ -37,11 +40,15 @@ def run(args): parser_change_pw.set_defaults(func=change_password) args = parser.parse_args(args) - path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA) - args.func(hass_auth.load_data(path), args) + loop = asyncio.get_event_loop() + hass = HomeAssistant(loop=loop) + hass.config.config_dir = os.path.join(os.getcwd(), args.config) + data = hass_auth.Data(hass) + loop.run_until_complete(data.async_load()) + loop.run_until_complete(args.func(data, args)) -def list_users(data, args): +async def list_users(data, args): """List the users.""" count = 0 for user in data.users: @@ -52,14 +59,14 @@ def list_users(data, args): print("Total users:", count) -def add_user(data, args): +async def add_user(data, args): """Create a user.""" data.add_user(args.username, args.password) - data.save() + await data.async_save() print("User created") -def validate_login(data, args): +async def validate_login(data, args): """Validate a login.""" try: data.validate_login(args.username, args.password) @@ -68,11 +75,11 @@ def validate_login(data, args): print("Auth invalid") -def change_password(data, args): +async def change_password(data, args): """Change password.""" try: data.change_password(args.username, args.new_password) - data.save() + await data.async_save() print("Password changed") except hass_auth.InvalidUser: print("User not found") diff --git a/tests/auth_providers/test_homeassistant.py b/tests/auth_providers/test_homeassistant.py index 8b12e682865..1d9a29bf48b 100644 --- a/tests/auth_providers/test_homeassistant.py +++ b/tests/auth_providers/test_homeassistant.py @@ -1,60 +1,48 @@ """Test the Home Assistant local auth provider.""" -from unittest.mock import patch, mock_open - import pytest from homeassistant import data_entry_flow from homeassistant.auth_providers import homeassistant as hass_auth -MOCK_PATH = '/bla/users.json' -JSON__OPEN_PATH = 'homeassistant.util.json.open' +@pytest.fixture +def data(hass): + """Create a loaded data class.""" + data = hass_auth.Data(hass) + hass.loop.run_until_complete(data.async_load()) + return data -def test_initialize_empty_config_file_not_found(): - """Test that we initialize an empty config.""" - with patch('homeassistant.util.json.open', side_effect=FileNotFoundError): - data = hass_auth.load_data(MOCK_PATH) - - assert data is not None - - -def test_adding_user(): +async def test_adding_user(data, hass): """Test adding a user.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') data.validate_login('test-user', 'test-pass') -def test_adding_user_duplicate_username(): +async def test_adding_user_duplicate_username(data, hass): """Test adding a user.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidUser): data.add_user('test-user', 'other-pass') -def test_validating_password_invalid_user(): +async def test_validating_password_invalid_user(data, hass): """Test validating an invalid user.""" - data = hass_auth.Data(MOCK_PATH, None) - with pytest.raises(hass_auth.InvalidAuth): data.validate_login('non-existing', 'pw') -def test_validating_password_invalid_password(): +async def test_validating_password_invalid_password(data, hass): """Test validating an invalid user.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidAuth): data.validate_login('test-user', 'invalid-pass') -def test_changing_password(): +async def test_changing_password(data, hass): """Test adding a user.""" user = 'test-user' - data = hass_auth.Data(MOCK_PATH, None) data.add_user(user, 'test-pass') data.change_password(user, 'new-pass') @@ -64,61 +52,50 @@ def test_changing_password(): data.validate_login(user, 'new-pass') -def test_changing_password_raises_invalid_user(): +async def test_changing_password_raises_invalid_user(data, hass): """Test that we initialize an empty config.""" - data = hass_auth.Data(MOCK_PATH, None) - with pytest.raises(hass_auth.InvalidUser): data.change_password('non-existing', 'pw') -async def test_login_flow_validates(hass): +async def test_login_flow_validates(data, hass): """Test login flow.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') + await data.async_save() provider = hass_auth.HassAuthProvider(hass, None, {}) flow = hass_auth.LoginFlow(provider) result = await flow.async_step_init() assert result['type'] == data_entry_flow.RESULT_TYPE_FORM - with patch.object(provider, '_auth_data', return_value=data): - result = await flow.async_step_init({ - 'username': 'incorrect-user', - 'password': 'test-pass', - }) - assert result['type'] == data_entry_flow.RESULT_TYPE_FORM - assert result['errors']['base'] == 'invalid_auth' + result = await flow.async_step_init({ + 'username': 'incorrect-user', + 'password': 'test-pass', + }) + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['errors']['base'] == 'invalid_auth' - result = await flow.async_step_init({ - 'username': 'test-user', - 'password': 'incorrect-pass', - }) - assert result['type'] == data_entry_flow.RESULT_TYPE_FORM - assert result['errors']['base'] == 'invalid_auth' + result = await flow.async_step_init({ + 'username': 'test-user', + 'password': 'incorrect-pass', + }) + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['errors']['base'] == 'invalid_auth' - result = await flow.async_step_init({ - 'username': 'test-user', - 'password': 'test-pass', - }) - assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + result = await flow.async_step_init({ + 'username': 'test-user', + 'password': 'test-pass', + }) + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY -async def test_saving_loading(hass): +async def test_saving_loading(data, hass): """Test saving and loading JSON.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') data.add_user('second-user', 'second-pass') + await data.async_save() - with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write: - await hass.async_add_job(data.save) - - # Mock open calls are: open file, context enter, write, context leave - written = mock_write.mock_calls[2][1][0] - - with patch('os.path.isfile', return_value=True), \ - patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True): - await hass.async_add_job(hass_auth.load_data, MOCK_PATH) - + data = hass_auth.Data(hass) + await data.async_load() data.validate_login('test-user', 'test-pass') data.validate_login('second-user', 'second-pass') diff --git a/tests/scripts/test_auth.py b/tests/scripts/test_auth.py index 2e837b06b58..e6aa7893f33 100644 --- a/tests/scripts/test_auth.py +++ b/tests/scripts/test_auth.py @@ -6,16 +6,21 @@ import pytest from homeassistant.scripts import auth as script_auth from homeassistant.auth_providers import homeassistant as hass_auth -MOCK_PATH = '/bla/users.json' + +@pytest.fixture +def data(hass): + """Create a loaded data class.""" + data = hass_auth.Data(hass) + hass.loop.run_until_complete(data.async_load()) + return data -def test_list_user(capsys): +async def test_list_user(data, capsys): """Test we can list users.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') data.add_user('second-user', 'second-pass') - script_auth.list_users(data, None) + await script_auth.list_users(data, None) captured = capsys.readouterr() @@ -28,15 +33,12 @@ def test_list_user(capsys): ]) -def test_add_user(capsys): +async def test_add_user(data, capsys, hass_storage): """Test we can add a user.""" - data = hass_auth.Data(MOCK_PATH, None) + await script_auth.add_user( + data, Mock(username='paulus', password='test-pass')) - with patch.object(data, 'save') as mock_save: - script_auth.add_user( - data, Mock(username='paulus', password='test-pass')) - - assert len(mock_save.mock_calls) == 1 + assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 captured = capsys.readouterr() assert captured.out == 'User created\n' @@ -45,37 +47,34 @@ def test_add_user(capsys): data.validate_login('paulus', 'test-pass') -def test_validate_login(capsys): +async def test_validate_login(data, capsys): """Test we can validate a user login.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') - script_auth.validate_login( + await script_auth.validate_login( data, Mock(username='test-user', password='test-pass')) captured = capsys.readouterr() assert captured.out == 'Auth valid\n' - script_auth.validate_login( + await script_auth.validate_login( data, Mock(username='test-user', password='invalid-pass')) captured = capsys.readouterr() assert captured.out == 'Auth invalid\n' - script_auth.validate_login( + await script_auth.validate_login( data, Mock(username='invalid-user', password='test-pass')) captured = capsys.readouterr() assert captured.out == 'Auth invalid\n' -def test_change_password(capsys): +async def test_change_password(data, capsys, hass_storage): """Test we can change a password.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') - with patch.object(data, 'save') as mock_save: - script_auth.change_password( - data, Mock(username='test-user', new_password='new-pass')) + await script_auth.change_password( + data, Mock(username='test-user', new_password='new-pass')) - assert len(mock_save.mock_calls) == 1 + assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 captured = capsys.readouterr() assert captured.out == 'Password changed\n' data.validate_login('test-user', 'new-pass') @@ -83,18 +82,35 @@ def test_change_password(capsys): data.validate_login('test-user', 'test-pass') -def test_change_password_invalid_user(capsys): +async def test_change_password_invalid_user(data, capsys, hass_storage): """Test changing password of non-existing user.""" - data = hass_auth.Data(MOCK_PATH, None) data.add_user('test-user', 'test-pass') - with patch.object(data, 'save') as mock_save: - script_auth.change_password( - data, Mock(username='invalid-user', new_password='new-pass')) + await script_auth.change_password( + data, Mock(username='invalid-user', new_password='new-pass')) - assert len(mock_save.mock_calls) == 0 + assert hass_auth.STORAGE_KEY not in hass_storage captured = capsys.readouterr() assert captured.out == 'User not found\n' data.validate_login('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidAuth): data.validate_login('invalid-user', 'new-pass') + + +def test_parsing_args(loop): + """Test we parse args correctly.""" + called = False + + async def mock_func(data, args2): + """Mock function to be called.""" + nonlocal called + called = True + assert data.hass.config.config_dir == '/somewhere/config' + assert args2 is args + + args = Mock(config='/somewhere/config', func=mock_func) + + with patch('argparse.ArgumentParser.parse_args', return_value=args): + script_auth.run(None) + + assert called, 'Mock function did not get called'