Migrate home assistant auth provider to use storage helper (#15200)

This commit is contained in:
Paulus Schoutsen 2018-06-29 00:02:45 -04:00 committed by GitHub
parent 39971ee919
commit 26590e244c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 116 deletions

View file

@ -8,10 +8,10 @@ import voluptuous as vol
from homeassistant import auth, data_entry_flow from homeassistant import auth, data_entry_flow
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import json
PATH_DATA = '.users.json' STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant'
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA) }, extra=vol.PREVENT_EXTRA)
@ -31,14 +31,22 @@ class InvalidUser(HomeAssistantError):
class Data: class Data:
"""Hold the user data.""" """Hold the user data."""
def __init__(self, path, data): def __init__(self, hass):
"""Initialize the user data store.""" """Initialize the user data store."""
self.path = path self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._data = None
async def async_load(self):
"""Load stored data."""
data = await self._store.async_load()
if data is None: if data is None:
data = { data = {
'salt': auth.generate_secret(), 'salt': auth.generate_secret(),
'users': [] 'users': []
} }
self._data = data self._data = data
@property @property
@ -99,14 +107,9 @@ class Data:
else: else:
raise InvalidUser raise InvalidUser
def save(self): async def async_save(self):
"""Save data.""" """Save data."""
json.save_json(self.path, self._data) await self._store.async_save(self._data)
def load_data(path):
"""Load auth data."""
return Data(path, json.load_json(path, None))
@auth.AUTH_PROVIDERS.register('homeassistant') @auth.AUTH_PROVIDERS.register('homeassistant')
@ -121,12 +124,10 @@ class HassAuthProvider(auth.AuthProvider):
async def async_validate_login(self, username, password): async def async_validate_login(self, username, password):
"""Helper to validate a username and password.""" """Helper to validate a username and password."""
def validate(): data = Data(self.hass)
"""Validate creds.""" await data.async_load()
data = self._auth_data() await self.hass.async_add_executor_job(
data.validate_login(username, password) data.validate_login, username, password)
await self.hass.async_add_job(validate)
async def async_get_or_create_credentials(self, flow_result): async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
@ -141,10 +142,6 @@ class HassAuthProvider(auth.AuthProvider):
'username': username 'username': username
}) })
def _auth_data(self):
"""Return the auth provider data."""
return load_data(self.hass.config.path(PATH_DATA))
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow.""" """Handler for the login flow."""

View file

@ -1,7 +1,9 @@
"""Script to manage users for the Home Assistant auth provider.""" """Script to manage users for the Home Assistant auth provider."""
import argparse import argparse
import asyncio
import os import os
from homeassistant.core import HomeAssistant
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
from homeassistant.auth_providers import homeassistant as hass_auth from homeassistant.auth_providers import homeassistant as hass_auth
@ -17,7 +19,8 @@ def run(args):
default=get_default_config_dir(), default=get_default_config_dir(),
help="Directory that contains the Home Assistant configuration") help="Directory that contains the Home Assistant configuration")
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers(dest='func')
subparsers.required = True
parser_list = subparsers.add_parser('list') parser_list = subparsers.add_parser('list')
parser_list.set_defaults(func=list_users) parser_list.set_defaults(func=list_users)
@ -37,11 +40,15 @@ def run(args):
parser_change_pw.set_defaults(func=change_password) parser_change_pw.set_defaults(func=change_password)
args = parser.parse_args(args) args = parser.parse_args(args)
path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA) loop = asyncio.get_event_loop()
args.func(hass_auth.load_data(path), args) hass = HomeAssistant(loop=loop)
hass.config.config_dir = os.path.join(os.getcwd(), args.config)
data = hass_auth.Data(hass)
loop.run_until_complete(data.async_load())
loop.run_until_complete(args.func(data, args))
def list_users(data, args): async def list_users(data, args):
"""List the users.""" """List the users."""
count = 0 count = 0
for user in data.users: for user in data.users:
@ -52,14 +59,14 @@ def list_users(data, args):
print("Total users:", count) print("Total users:", count)
def add_user(data, args): async def add_user(data, args):
"""Create a user.""" """Create a user."""
data.add_user(args.username, args.password) data.add_user(args.username, args.password)
data.save() await data.async_save()
print("User created") print("User created")
def validate_login(data, args): async def validate_login(data, args):
"""Validate a login.""" """Validate a login."""
try: try:
data.validate_login(args.username, args.password) data.validate_login(args.username, args.password)
@ -68,11 +75,11 @@ def validate_login(data, args):
print("Auth invalid") print("Auth invalid")
def change_password(data, args): async def change_password(data, args):
"""Change password.""" """Change password."""
try: try:
data.change_password(args.username, args.new_password) data.change_password(args.username, args.new_password)
data.save() await data.async_save()
print("Password changed") print("Password changed")
except hass_auth.InvalidUser: except hass_auth.InvalidUser:
print("User not found") print("User not found")

View file

@ -1,60 +1,48 @@
"""Test the Home Assistant local auth provider.""" """Test the Home Assistant local auth provider."""
from unittest.mock import patch, mock_open
import pytest import pytest
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth_providers import homeassistant as hass_auth from homeassistant.auth_providers import homeassistant as hass_auth
MOCK_PATH = '/bla/users.json' @pytest.fixture
JSON__OPEN_PATH = 'homeassistant.util.json.open' def data(hass):
"""Create a loaded data class."""
data = hass_auth.Data(hass)
hass.loop.run_until_complete(data.async_load())
return data
def test_initialize_empty_config_file_not_found(): async def test_adding_user(data, hass):
"""Test that we initialize an empty config."""
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
data = hass_auth.load_data(MOCK_PATH)
assert data is not None
def test_adding_user():
"""Test adding a user.""" """Test adding a user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
def test_adding_user_duplicate_username(): async def test_adding_user_duplicate_username(data, hass):
"""Test adding a user.""" """Test adding a user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidUser): with pytest.raises(hass_auth.InvalidUser):
data.add_user('test-user', 'other-pass') data.add_user('test-user', 'other-pass')
def test_validating_password_invalid_user(): async def test_validating_password_invalid_user(data, hass):
"""Test validating an invalid user.""" """Test validating an invalid user."""
data = hass_auth.Data(MOCK_PATH, None)
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('non-existing', 'pw') data.validate_login('non-existing', 'pw')
def test_validating_password_invalid_password(): async def test_validating_password_invalid_password(data, hass):
"""Test validating an invalid user.""" """Test validating an invalid user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('test-user', 'invalid-pass') data.validate_login('test-user', 'invalid-pass')
def test_changing_password(): async def test_changing_password(data, hass):
"""Test adding a user.""" """Test adding a user."""
user = 'test-user' user = 'test-user'
data = hass_auth.Data(MOCK_PATH, None)
data.add_user(user, 'test-pass') data.add_user(user, 'test-pass')
data.change_password(user, 'new-pass') data.change_password(user, 'new-pass')
@ -64,61 +52,50 @@ def test_changing_password():
data.validate_login(user, 'new-pass') data.validate_login(user, 'new-pass')
def test_changing_password_raises_invalid_user(): async def test_changing_password_raises_invalid_user(data, hass):
"""Test that we initialize an empty config.""" """Test that we initialize an empty config."""
data = hass_auth.Data(MOCK_PATH, None)
with pytest.raises(hass_auth.InvalidUser): with pytest.raises(hass_auth.InvalidUser):
data.change_password('non-existing', 'pw') data.change_password('non-existing', 'pw')
async def test_login_flow_validates(hass): async def test_login_flow_validates(data, hass):
"""Test login flow.""" """Test login flow."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
await data.async_save()
provider = hass_auth.HassAuthProvider(hass, None, {}) provider = hass_auth.HassAuthProvider(hass, None, {})
flow = hass_auth.LoginFlow(provider) flow = hass_auth.LoginFlow(provider)
result = await flow.async_step_init() result = await flow.async_step_init()
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
with patch.object(provider, '_auth_data', return_value=data): result = await flow.async_step_init({
result = await flow.async_step_init({ 'username': 'incorrect-user',
'username': 'incorrect-user', 'password': 'test-pass',
'password': 'test-pass', })
}) assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM assert result['errors']['base'] == 'invalid_auth'
assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({ result = await flow.async_step_init({
'username': 'test-user', 'username': 'test-user',
'password': 'incorrect-pass', 'password': 'incorrect-pass',
}) })
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors']['base'] == 'invalid_auth' assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({ result = await flow.async_step_init({
'username': 'test-user', 'username': 'test-user',
'password': 'test-pass', 'password': 'test-pass',
}) })
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
async def test_saving_loading(hass): async def test_saving_loading(data, hass):
"""Test saving and loading JSON.""" """Test saving and loading JSON."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
data.add_user('second-user', 'second-pass') data.add_user('second-user', 'second-pass')
await data.async_save()
with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write: data = hass_auth.Data(hass)
await hass.async_add_job(data.save) await data.async_load()
# Mock open calls are: open file, context enter, write, context leave
written = mock_write.mock_calls[2][1][0]
with patch('os.path.isfile', return_value=True), \
patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True):
await hass.async_add_job(hass_auth.load_data, MOCK_PATH)
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
data.validate_login('second-user', 'second-pass') data.validate_login('second-user', 'second-pass')

View file

@ -6,16 +6,21 @@ import pytest
from homeassistant.scripts import auth as script_auth from homeassistant.scripts import auth as script_auth
from homeassistant.auth_providers import homeassistant as hass_auth from homeassistant.auth_providers import homeassistant as hass_auth
MOCK_PATH = '/bla/users.json'
@pytest.fixture
def data(hass):
"""Create a loaded data class."""
data = hass_auth.Data(hass)
hass.loop.run_until_complete(data.async_load())
return data
def test_list_user(capsys): async def test_list_user(data, capsys):
"""Test we can list users.""" """Test we can list users."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
data.add_user('second-user', 'second-pass') data.add_user('second-user', 'second-pass')
script_auth.list_users(data, None) await script_auth.list_users(data, None)
captured = capsys.readouterr() captured = capsys.readouterr()
@ -28,15 +33,12 @@ def test_list_user(capsys):
]) ])
def test_add_user(capsys): async def test_add_user(data, capsys, hass_storage):
"""Test we can add a user.""" """Test we can add a user."""
data = hass_auth.Data(MOCK_PATH, None) await script_auth.add_user(
data, Mock(username='paulus', password='test-pass'))
with patch.object(data, 'save') as mock_save: assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
script_auth.add_user(
data, Mock(username='paulus', password='test-pass'))
assert len(mock_save.mock_calls) == 1
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'User created\n' assert captured.out == 'User created\n'
@ -45,37 +47,34 @@ def test_add_user(capsys):
data.validate_login('paulus', 'test-pass') data.validate_login('paulus', 'test-pass')
def test_validate_login(capsys): async def test_validate_login(data, capsys):
"""Test we can validate a user login.""" """Test we can validate a user login."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='test-user', password='test-pass')) data, Mock(username='test-user', password='test-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth valid\n' assert captured.out == 'Auth valid\n'
script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='test-user', password='invalid-pass')) data, Mock(username='test-user', password='invalid-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n' assert captured.out == 'Auth invalid\n'
script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='invalid-user', password='test-pass')) data, Mock(username='invalid-user', password='test-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n' assert captured.out == 'Auth invalid\n'
def test_change_password(capsys): async def test_change_password(data, capsys, hass_storage):
"""Test we can change a password.""" """Test we can change a password."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
with patch.object(data, 'save') as mock_save: await script_auth.change_password(
script_auth.change_password( data, Mock(username='test-user', new_password='new-pass'))
data, Mock(username='test-user', new_password='new-pass'))
assert len(mock_save.mock_calls) == 1 assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Password changed\n' assert captured.out == 'Password changed\n'
data.validate_login('test-user', 'new-pass') data.validate_login('test-user', 'new-pass')
@ -83,18 +82,35 @@ def test_change_password(capsys):
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
def test_change_password_invalid_user(capsys): async def test_change_password_invalid_user(data, capsys, hass_storage):
"""Test changing password of non-existing user.""" """Test changing password of non-existing user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass') data.add_user('test-user', 'test-pass')
with patch.object(data, 'save') as mock_save: await script_auth.change_password(
script_auth.change_password( data, Mock(username='invalid-user', new_password='new-pass'))
data, Mock(username='invalid-user', new_password='new-pass'))
assert len(mock_save.mock_calls) == 0 assert hass_auth.STORAGE_KEY not in hass_storage
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'User not found\n' assert captured.out == 'User not found\n'
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('invalid-user', 'new-pass') data.validate_login('invalid-user', 'new-pass')
def test_parsing_args(loop):
"""Test we parse args correctly."""
called = False
async def mock_func(data, args2):
"""Mock function to be called."""
nonlocal called
called = True
assert data.hass.config.config_dir == '/somewhere/config'
assert args2 is args
args = Mock(config='/somewhere/config', func=mock_func)
with patch('argparse.ArgumentParser.parse_args', return_value=args):
script_auth.run(None)
assert called, 'Mock function did not get called'