Migrate home assistant auth provider to use storage helper (#15200)
This commit is contained in:
parent
39971ee919
commit
26590e244c
4 changed files with 113 additions and 116 deletions
|
@ -8,10 +8,10 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import auth, data_entry_flow
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util import json
|
|
||||||
|
|
||||||
|
|
||||||
PATH_DATA = '.users.json'
|
STORAGE_VERSION = 1
|
||||||
|
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||||
|
|
||||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
@ -31,14 +31,22 @@ class InvalidUser(HomeAssistantError):
|
||||||
class Data:
|
class Data:
|
||||||
"""Hold the user data."""
|
"""Hold the user data."""
|
||||||
|
|
||||||
def __init__(self, path, data):
|
def __init__(self, hass):
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
self.path = path
|
self.hass = hass
|
||||||
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
self._data = None
|
||||||
|
|
||||||
|
async def async_load(self):
|
||||||
|
"""Load stored data."""
|
||||||
|
data = await self._store.async_load()
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {
|
data = {
|
||||||
'salt': auth.generate_secret(),
|
'salt': auth.generate_secret(),
|
||||||
'users': []
|
'users': []
|
||||||
}
|
}
|
||||||
|
|
||||||
self._data = data
|
self._data = data
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -99,14 +107,9 @@ class Data:
|
||||||
else:
|
else:
|
||||||
raise InvalidUser
|
raise InvalidUser
|
||||||
|
|
||||||
def save(self):
|
async def async_save(self):
|
||||||
"""Save data."""
|
"""Save data."""
|
||||||
json.save_json(self.path, self._data)
|
await self._store.async_save(self._data)
|
||||||
|
|
||||||
|
|
||||||
def load_data(path):
|
|
||||||
"""Load auth data."""
|
|
||||||
return Data(path, json.load_json(path, None))
|
|
||||||
|
|
||||||
|
|
||||||
@auth.AUTH_PROVIDERS.register('homeassistant')
|
@auth.AUTH_PROVIDERS.register('homeassistant')
|
||||||
|
@ -121,12 +124,10 @@ class HassAuthProvider(auth.AuthProvider):
|
||||||
|
|
||||||
async def async_validate_login(self, username, password):
|
async def async_validate_login(self, username, password):
|
||||||
"""Helper to validate a username and password."""
|
"""Helper to validate a username and password."""
|
||||||
def validate():
|
data = Data(self.hass)
|
||||||
"""Validate creds."""
|
await data.async_load()
|
||||||
data = self._auth_data()
|
await self.hass.async_add_executor_job(
|
||||||
data.validate_login(username, password)
|
data.validate_login, username, password)
|
||||||
|
|
||||||
await self.hass.async_add_job(validate)
|
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(self, flow_result):
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
|
@ -141,10 +142,6 @@ class HassAuthProvider(auth.AuthProvider):
|
||||||
'username': username
|
'username': username
|
||||||
})
|
})
|
||||||
|
|
||||||
def _auth_data(self):
|
|
||||||
"""Return the auth provider data."""
|
|
||||||
return load_data(self.hass.config.path(PATH_DATA))
|
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Script to manage users for the Home Assistant auth provider."""
|
"""Script to manage users for the Home Assistant auth provider."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.config import get_default_config_dir
|
from homeassistant.config import get_default_config_dir
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
@ -17,7 +19,8 @@ def run(args):
|
||||||
default=get_default_config_dir(),
|
default=get_default_config_dir(),
|
||||||
help="Directory that contains the Home Assistant configuration")
|
help="Directory that contains the Home Assistant configuration")
|
||||||
|
|
||||||
subparsers = parser.add_subparsers()
|
subparsers = parser.add_subparsers(dest='func')
|
||||||
|
subparsers.required = True
|
||||||
parser_list = subparsers.add_parser('list')
|
parser_list = subparsers.add_parser('list')
|
||||||
parser_list.set_defaults(func=list_users)
|
parser_list.set_defaults(func=list_users)
|
||||||
|
|
||||||
|
@ -37,11 +40,15 @@ def run(args):
|
||||||
parser_change_pw.set_defaults(func=change_password)
|
parser_change_pw.set_defaults(func=change_password)
|
||||||
|
|
||||||
args = parser.parse_args(args)
|
args = parser.parse_args(args)
|
||||||
path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA)
|
loop = asyncio.get_event_loop()
|
||||||
args.func(hass_auth.load_data(path), args)
|
hass = HomeAssistant(loop=loop)
|
||||||
|
hass.config.config_dir = os.path.join(os.getcwd(), args.config)
|
||||||
|
data = hass_auth.Data(hass)
|
||||||
|
loop.run_until_complete(data.async_load())
|
||||||
|
loop.run_until_complete(args.func(data, args))
|
||||||
|
|
||||||
|
|
||||||
def list_users(data, args):
|
async def list_users(data, args):
|
||||||
"""List the users."""
|
"""List the users."""
|
||||||
count = 0
|
count = 0
|
||||||
for user in data.users:
|
for user in data.users:
|
||||||
|
@ -52,14 +59,14 @@ def list_users(data, args):
|
||||||
print("Total users:", count)
|
print("Total users:", count)
|
||||||
|
|
||||||
|
|
||||||
def add_user(data, args):
|
async def add_user(data, args):
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
data.add_user(args.username, args.password)
|
data.add_user(args.username, args.password)
|
||||||
data.save()
|
await data.async_save()
|
||||||
print("User created")
|
print("User created")
|
||||||
|
|
||||||
|
|
||||||
def validate_login(data, args):
|
async def validate_login(data, args):
|
||||||
"""Validate a login."""
|
"""Validate a login."""
|
||||||
try:
|
try:
|
||||||
data.validate_login(args.username, args.password)
|
data.validate_login(args.username, args.password)
|
||||||
|
@ -68,11 +75,11 @@ def validate_login(data, args):
|
||||||
print("Auth invalid")
|
print("Auth invalid")
|
||||||
|
|
||||||
|
|
||||||
def change_password(data, args):
|
async def change_password(data, args):
|
||||||
"""Change password."""
|
"""Change password."""
|
||||||
try:
|
try:
|
||||||
data.change_password(args.username, args.new_password)
|
data.change_password(args.username, args.new_password)
|
||||||
data.save()
|
await data.async_save()
|
||||||
print("Password changed")
|
print("Password changed")
|
||||||
except hass_auth.InvalidUser:
|
except hass_auth.InvalidUser:
|
||||||
print("User not found")
|
print("User not found")
|
||||||
|
|
|
@ -1,60 +1,48 @@
|
||||||
"""Test the Home Assistant local auth provider."""
|
"""Test the Home Assistant local auth provider."""
|
||||||
from unittest.mock import patch, mock_open
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
MOCK_PATH = '/bla/users.json'
|
@pytest.fixture
|
||||||
JSON__OPEN_PATH = 'homeassistant.util.json.open'
|
def data(hass):
|
||||||
|
"""Create a loaded data class."""
|
||||||
|
data = hass_auth.Data(hass)
|
||||||
|
hass.loop.run_until_complete(data.async_load())
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_empty_config_file_not_found():
|
async def test_adding_user(data, hass):
|
||||||
"""Test that we initialize an empty config."""
|
|
||||||
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
|
|
||||||
data = hass_auth.load_data(MOCK_PATH)
|
|
||||||
|
|
||||||
assert data is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_adding_user():
|
|
||||||
"""Test adding a user."""
|
"""Test adding a user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
data.validate_login('test-user', 'test-pass')
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_adding_user_duplicate_username():
|
async def test_adding_user_duplicate_username(data, hass):
|
||||||
"""Test adding a user."""
|
"""Test adding a user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
with pytest.raises(hass_auth.InvalidUser):
|
with pytest.raises(hass_auth.InvalidUser):
|
||||||
data.add_user('test-user', 'other-pass')
|
data.add_user('test-user', 'other-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_validating_password_invalid_user():
|
async def test_validating_password_invalid_user(data, hass):
|
||||||
"""Test validating an invalid user."""
|
"""Test validating an invalid user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
|
|
||||||
with pytest.raises(hass_auth.InvalidAuth):
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
data.validate_login('non-existing', 'pw')
|
data.validate_login('non-existing', 'pw')
|
||||||
|
|
||||||
|
|
||||||
def test_validating_password_invalid_password():
|
async def test_validating_password_invalid_password(data, hass):
|
||||||
"""Test validating an invalid user."""
|
"""Test validating an invalid user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
with pytest.raises(hass_auth.InvalidAuth):
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
data.validate_login('test-user', 'invalid-pass')
|
data.validate_login('test-user', 'invalid-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_changing_password():
|
async def test_changing_password(data, hass):
|
||||||
"""Test adding a user."""
|
"""Test adding a user."""
|
||||||
user = 'test-user'
|
user = 'test-user'
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user(user, 'test-pass')
|
data.add_user(user, 'test-pass')
|
||||||
data.change_password(user, 'new-pass')
|
data.change_password(user, 'new-pass')
|
||||||
|
|
||||||
|
@ -64,25 +52,22 @@ def test_changing_password():
|
||||||
data.validate_login(user, 'new-pass')
|
data.validate_login(user, 'new-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_changing_password_raises_invalid_user():
|
async def test_changing_password_raises_invalid_user(data, hass):
|
||||||
"""Test that we initialize an empty config."""
|
"""Test that we initialize an empty config."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
|
|
||||||
with pytest.raises(hass_auth.InvalidUser):
|
with pytest.raises(hass_auth.InvalidUser):
|
||||||
data.change_password('non-existing', 'pw')
|
data.change_password('non-existing', 'pw')
|
||||||
|
|
||||||
|
|
||||||
async def test_login_flow_validates(hass):
|
async def test_login_flow_validates(data, hass):
|
||||||
"""Test login flow."""
|
"""Test login flow."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
|
await data.async_save()
|
||||||
|
|
||||||
provider = hass_auth.HassAuthProvider(hass, None, {})
|
provider = hass_auth.HassAuthProvider(hass, None, {})
|
||||||
flow = hass_auth.LoginFlow(provider)
|
flow = hass_auth.LoginFlow(provider)
|
||||||
result = await flow.async_step_init()
|
result = await flow.async_step_init()
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
|
||||||
with patch.object(provider, '_auth_data', return_value=data):
|
|
||||||
result = await flow.async_step_init({
|
result = await flow.async_step_init({
|
||||||
'username': 'incorrect-user',
|
'username': 'incorrect-user',
|
||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
|
@ -104,21 +89,13 @@ async def test_login_flow_validates(hass):
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
|
||||||
|
|
||||||
async def test_saving_loading(hass):
|
async def test_saving_loading(data, hass):
|
||||||
"""Test saving and loading JSON."""
|
"""Test saving and loading JSON."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
data.add_user('second-user', 'second-pass')
|
data.add_user('second-user', 'second-pass')
|
||||||
|
await data.async_save()
|
||||||
|
|
||||||
with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write:
|
data = hass_auth.Data(hass)
|
||||||
await hass.async_add_job(data.save)
|
await data.async_load()
|
||||||
|
|
||||||
# Mock open calls are: open file, context enter, write, context leave
|
|
||||||
written = mock_write.mock_calls[2][1][0]
|
|
||||||
|
|
||||||
with patch('os.path.isfile', return_value=True), \
|
|
||||||
patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True):
|
|
||||||
await hass.async_add_job(hass_auth.load_data, MOCK_PATH)
|
|
||||||
|
|
||||||
data.validate_login('test-user', 'test-pass')
|
data.validate_login('test-user', 'test-pass')
|
||||||
data.validate_login('second-user', 'second-pass')
|
data.validate_login('second-user', 'second-pass')
|
||||||
|
|
|
@ -6,16 +6,21 @@ import pytest
|
||||||
from homeassistant.scripts import auth as script_auth
|
from homeassistant.scripts import auth as script_auth
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
MOCK_PATH = '/bla/users.json'
|
|
||||||
|
@pytest.fixture
|
||||||
|
def data(hass):
|
||||||
|
"""Create a loaded data class."""
|
||||||
|
data = hass_auth.Data(hass)
|
||||||
|
hass.loop.run_until_complete(data.async_load())
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def test_list_user(capsys):
|
async def test_list_user(data, capsys):
|
||||||
"""Test we can list users."""
|
"""Test we can list users."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
data.add_user('second-user', 'second-pass')
|
data.add_user('second-user', 'second-pass')
|
||||||
|
|
||||||
script_auth.list_users(data, None)
|
await script_auth.list_users(data, None)
|
||||||
|
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
@ -28,15 +33,12 @@ def test_list_user(capsys):
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def test_add_user(capsys):
|
async def test_add_user(data, capsys, hass_storage):
|
||||||
"""Test we can add a user."""
|
"""Test we can add a user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
await script_auth.add_user(
|
||||||
|
|
||||||
with patch.object(data, 'save') as mock_save:
|
|
||||||
script_auth.add_user(
|
|
||||||
data, Mock(username='paulus', password='test-pass'))
|
data, Mock(username='paulus', password='test-pass'))
|
||||||
|
|
||||||
assert len(mock_save.mock_calls) == 1
|
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||||
|
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'User created\n'
|
assert captured.out == 'User created\n'
|
||||||
|
@ -45,37 +47,34 @@ def test_add_user(capsys):
|
||||||
data.validate_login('paulus', 'test-pass')
|
data.validate_login('paulus', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_validate_login(capsys):
|
async def test_validate_login(data, capsys):
|
||||||
"""Test we can validate a user login."""
|
"""Test we can validate a user login."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
script_auth.validate_login(
|
await script_auth.validate_login(
|
||||||
data, Mock(username='test-user', password='test-pass'))
|
data, Mock(username='test-user', password='test-pass'))
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'Auth valid\n'
|
assert captured.out == 'Auth valid\n'
|
||||||
|
|
||||||
script_auth.validate_login(
|
await script_auth.validate_login(
|
||||||
data, Mock(username='test-user', password='invalid-pass'))
|
data, Mock(username='test-user', password='invalid-pass'))
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'Auth invalid\n'
|
assert captured.out == 'Auth invalid\n'
|
||||||
|
|
||||||
script_auth.validate_login(
|
await script_auth.validate_login(
|
||||||
data, Mock(username='invalid-user', password='test-pass'))
|
data, Mock(username='invalid-user', password='test-pass'))
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'Auth invalid\n'
|
assert captured.out == 'Auth invalid\n'
|
||||||
|
|
||||||
|
|
||||||
def test_change_password(capsys):
|
async def test_change_password(data, capsys, hass_storage):
|
||||||
"""Test we can change a password."""
|
"""Test we can change a password."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
with patch.object(data, 'save') as mock_save:
|
await script_auth.change_password(
|
||||||
script_auth.change_password(
|
|
||||||
data, Mock(username='test-user', new_password='new-pass'))
|
data, Mock(username='test-user', new_password='new-pass'))
|
||||||
|
|
||||||
assert len(mock_save.mock_calls) == 1
|
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'Password changed\n'
|
assert captured.out == 'Password changed\n'
|
||||||
data.validate_login('test-user', 'new-pass')
|
data.validate_login('test-user', 'new-pass')
|
||||||
|
@ -83,18 +82,35 @@ def test_change_password(capsys):
|
||||||
data.validate_login('test-user', 'test-pass')
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
def test_change_password_invalid_user(capsys):
|
async def test_change_password_invalid_user(data, capsys, hass_storage):
|
||||||
"""Test changing password of non-existing user."""
|
"""Test changing password of non-existing user."""
|
||||||
data = hass_auth.Data(MOCK_PATH, None)
|
|
||||||
data.add_user('test-user', 'test-pass')
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
with patch.object(data, 'save') as mock_save:
|
await script_auth.change_password(
|
||||||
script_auth.change_password(
|
|
||||||
data, Mock(username='invalid-user', new_password='new-pass'))
|
data, Mock(username='invalid-user', new_password='new-pass'))
|
||||||
|
|
||||||
assert len(mock_save.mock_calls) == 0
|
assert hass_auth.STORAGE_KEY not in hass_storage
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == 'User not found\n'
|
assert captured.out == 'User not found\n'
|
||||||
data.validate_login('test-user', 'test-pass')
|
data.validate_login('test-user', 'test-pass')
|
||||||
with pytest.raises(hass_auth.InvalidAuth):
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
data.validate_login('invalid-user', 'new-pass')
|
data.validate_login('invalid-user', 'new-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_parsing_args(loop):
|
||||||
|
"""Test we parse args correctly."""
|
||||||
|
called = False
|
||||||
|
|
||||||
|
async def mock_func(data, args2):
|
||||||
|
"""Mock function to be called."""
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
assert data.hass.config.config_dir == '/somewhere/config'
|
||||||
|
assert args2 is args
|
||||||
|
|
||||||
|
args = Mock(config='/somewhere/config', func=mock_func)
|
||||||
|
|
||||||
|
with patch('argparse.ArgumentParser.parse_args', return_value=args):
|
||||||
|
script_auth.run(None)
|
||||||
|
|
||||||
|
assert called, 'Mock function did not get called'
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue