Add local auth provider (#14365)
* Add local auth provider * Lint * Docstring
This commit is contained in:
parent
6e831138b4
commit
ea01b127c2
7 changed files with 501 additions and 36 deletions
|
@ -15,7 +15,6 @@ from voluptuous.humanize import humanize_error
|
||||||
from homeassistant import data_entry_flow, requirements
|
from homeassistant import data_entry_flow, requirements
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
@ -36,22 +35,6 @@ ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
||||||
DATA_REQS = 'auth_reqs_processed'
|
DATA_REQS = 'auth_reqs_processed'
|
||||||
|
|
||||||
|
|
||||||
class AuthError(HomeAssistantError):
|
|
||||||
"""Generic authentication error."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidUser(AuthError):
|
|
||||||
"""Raised when an invalid user has been specified."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidPassword(AuthError):
|
|
||||||
"""Raised when an invalid password has been supplied."""
|
|
||||||
|
|
||||||
|
|
||||||
class UnknownError(AuthError):
|
|
||||||
"""When an unknown error occurs."""
|
|
||||||
|
|
||||||
|
|
||||||
def generate_secret(entropy=32):
|
def generate_secret(entropy=32):
|
||||||
"""Generate a secret.
|
"""Generate a secret.
|
||||||
|
|
||||||
|
@ -69,8 +52,9 @@ class AuthProvider:
|
||||||
|
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
||||||
def __init__(self, store, config):
|
def __init__(self, hass, store, config):
|
||||||
"""Initialize an auth provider."""
|
"""Initialize an auth provider."""
|
||||||
|
self.hass = hass
|
||||||
self.store = store
|
self.store = store
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -284,7 +268,7 @@ async def _auth_provider_from_config(hass, store, config):
|
||||||
provider_name, humanize_error(config, err))
|
provider_name, humanize_error(config, err))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return AUTH_PROVIDERS[provider_name](store, config)
|
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||||
|
|
||||||
|
|
||||||
class AuthManager:
|
class AuthManager:
|
||||||
|
|
181
homeassistant/auth_providers/homeassistant.py
Normal file
181
homeassistant/auth_providers/homeassistant.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
"""Home Assistant auth provider."""
|
||||||
|
import base64
|
||||||
|
from collections import OrderedDict
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import auth, data_entry_flow
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.util import json
|
||||||
|
|
||||||
|
|
||||||
|
PATH_DATA = '.users.json'
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||||
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuth(HomeAssistantError):
|
||||||
|
"""Raised when we encounter invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidUser(HomeAssistantError):
|
||||||
|
"""Raised when invalid user is specified.
|
||||||
|
|
||||||
|
Will not be raised when validating authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Data:
|
||||||
|
"""Hold the user data."""
|
||||||
|
|
||||||
|
def __init__(self, path, data):
|
||||||
|
"""Initialize the user data store."""
|
||||||
|
self.path = path
|
||||||
|
if data is None:
|
||||||
|
data = {
|
||||||
|
'salt': auth.generate_secret(),
|
||||||
|
'users': []
|
||||||
|
}
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def users(self):
|
||||||
|
"""Return users."""
|
||||||
|
return self._data['users']
|
||||||
|
|
||||||
|
def validate_login(self, username, password):
|
||||||
|
"""Validate a username and password.
|
||||||
|
|
||||||
|
Raises InvalidAuth if auth invalid.
|
||||||
|
"""
|
||||||
|
password = self.hash_password(password)
|
||||||
|
|
||||||
|
found = None
|
||||||
|
|
||||||
|
# Compare all users to avoid timing attacks.
|
||||||
|
for user in self._data['users']:
|
||||||
|
if username == user['username']:
|
||||||
|
found = user
|
||||||
|
|
||||||
|
if found is None:
|
||||||
|
# Do one more compare to make timing the same as if user was found.
|
||||||
|
hmac.compare_digest(password, password)
|
||||||
|
raise InvalidAuth
|
||||||
|
|
||||||
|
if not hmac.compare_digest(password,
|
||||||
|
base64.b64decode(found['password'])):
|
||||||
|
raise InvalidAuth
|
||||||
|
|
||||||
|
def hash_password(self, password, for_storage=False):
|
||||||
|
"""Encode a password."""
|
||||||
|
hashed = hashlib.pbkdf2_hmac(
|
||||||
|
'sha512', password.encode(), self._data['salt'].encode(), 100000)
|
||||||
|
if for_storage:
|
||||||
|
hashed = base64.b64encode(hashed).decode()
|
||||||
|
return hashed
|
||||||
|
|
||||||
|
def add_user(self, username, password):
|
||||||
|
"""Add a user."""
|
||||||
|
if any(user['username'] == username for user in self.users):
|
||||||
|
raise InvalidUser
|
||||||
|
|
||||||
|
self.users.append({
|
||||||
|
'username': username,
|
||||||
|
'password': self.hash_password(password, True),
|
||||||
|
})
|
||||||
|
|
||||||
|
def change_password(self, username, new_password):
|
||||||
|
"""Update the password of a user.
|
||||||
|
|
||||||
|
Raises InvalidUser if user cannot be found.
|
||||||
|
"""
|
||||||
|
for user in self.users:
|
||||||
|
if user['username'] == username:
|
||||||
|
user['password'] = self.hash_password(new_password, True)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise InvalidUser
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""Save data."""
|
||||||
|
json.save_json(self.path, self._data)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(path):
|
||||||
|
"""Load auth data."""
|
||||||
|
return Data(path, json.load_json(path, None))
|
||||||
|
|
||||||
|
|
||||||
|
@auth.AUTH_PROVIDERS.register('homeassistant')
|
||||||
|
class HassAuthProvider(auth.AuthProvider):
|
||||||
|
"""Auth provider based on a local storage of users in HASS config dir."""
|
||||||
|
|
||||||
|
DEFAULT_TITLE = 'Home Assistant Local'
|
||||||
|
|
||||||
|
async def async_credential_flow(self):
|
||||||
|
"""Return a flow to login."""
|
||||||
|
return LoginFlow(self)
|
||||||
|
|
||||||
|
async def async_validate_login(self, username, password):
|
||||||
|
"""Helper to validate a username and password."""
|
||||||
|
def validate():
|
||||||
|
"""Validate creds."""
|
||||||
|
data = self._auth_data()
|
||||||
|
data.validate_login(username, password)
|
||||||
|
|
||||||
|
await self.hass.async_add_job(validate)
|
||||||
|
|
||||||
|
async def async_get_or_create_credentials(self, flow_result):
|
||||||
|
"""Get credentials based on the flow result."""
|
||||||
|
username = flow_result['username']
|
||||||
|
|
||||||
|
for credential in await self.async_credentials():
|
||||||
|
if credential.data['username'] == username:
|
||||||
|
return credential
|
||||||
|
|
||||||
|
# Create new credentials.
|
||||||
|
return self.async_create_credentials({
|
||||||
|
'username': username
|
||||||
|
})
|
||||||
|
|
||||||
|
def _auth_data(self):
|
||||||
|
"""Return the auth provider data."""
|
||||||
|
return load_data(self.hass.config.path(PATH_DATA))
|
||||||
|
|
||||||
|
|
||||||
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
|
def __init__(self, auth_provider):
|
||||||
|
"""Initialize the login flow."""
|
||||||
|
self._auth_provider = auth_provider
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
"""Handle the step of the form."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
try:
|
||||||
|
await self._auth_provider.async_validate_login(
|
||||||
|
user_input['username'], user_input['password'])
|
||||||
|
except InvalidAuth:
|
||||||
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
|
if not errors:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self._auth_provider.name,
|
||||||
|
data=user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = OrderedDict()
|
||||||
|
schema['username'] = str
|
||||||
|
schema['password'] = str
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id='init',
|
||||||
|
data_schema=vol.Schema(schema),
|
||||||
|
errors=errors,
|
||||||
|
)
|
|
@ -4,6 +4,7 @@ import hmac
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import auth, data_entry_flow
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
@ -20,6 +21,10 @@ CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthError(HomeAssistantError):
|
||||||
|
"""Raised when submitting invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
@auth.AUTH_PROVIDERS.register('insecure_example')
|
@auth.AUTH_PROVIDERS.register('insecure_example')
|
||||||
class ExampleAuthProvider(auth.AuthProvider):
|
class ExampleAuthProvider(auth.AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
@ -43,18 +48,15 @@ class ExampleAuthProvider(auth.AuthProvider):
|
||||||
# Do one more compare to make timing the same as if user was found.
|
# Do one more compare to make timing the same as if user was found.
|
||||||
hmac.compare_digest(password.encode('utf-8'),
|
hmac.compare_digest(password.encode('utf-8'),
|
||||||
password.encode('utf-8'))
|
password.encode('utf-8'))
|
||||||
raise auth.InvalidUser
|
raise InvalidAuthError
|
||||||
|
|
||||||
if not hmac.compare_digest(user['password'].encode('utf-8'),
|
if not hmac.compare_digest(user['password'].encode('utf-8'),
|
||||||
password.encode('utf-8')):
|
password.encode('utf-8')):
|
||||||
raise auth.InvalidPassword
|
raise InvalidAuthError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(self, flow_result):
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
username = flow_result['username']
|
username = flow_result['username']
|
||||||
password = flow_result['password']
|
|
||||||
|
|
||||||
self.async_validate_login(username, password)
|
|
||||||
|
|
||||||
for credential in await self.async_credentials():
|
for credential in await self.async_credentials():
|
||||||
if credential.data['username'] == username:
|
if credential.data['username'] == username:
|
||||||
|
@ -96,7 +98,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
try:
|
try:
|
||||||
self._auth_provider.async_validate_login(
|
self._auth_provider.async_validate_login(
|
||||||
user_input['username'], user_input['password'])
|
user_input['username'], user_input['password'])
|
||||||
except (auth.InvalidUser, auth.InvalidPassword):
|
except InvalidAuthError:
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
|
|
78
homeassistant/scripts/auth.py
Normal file
78
homeassistant/scripts/auth.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
"""Script to manage users for the Home Assistant auth provider."""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from homeassistant.config import get_default_config_dir
|
||||||
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
|
def run(args):
|
||||||
|
"""Handle Home Assistant auth provider script."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=("Manage Home Assistant users"))
|
||||||
|
parser.add_argument(
|
||||||
|
'--script', choices=['auth'])
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--config',
|
||||||
|
default=get_default_config_dir(),
|
||||||
|
help="Directory that contains the Home Assistant configuration")
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers()
|
||||||
|
parser_list = subparsers.add_parser('list')
|
||||||
|
parser_list.set_defaults(func=list_users)
|
||||||
|
|
||||||
|
parser_add = subparsers.add_parser('add')
|
||||||
|
parser_add.add_argument('username', type=str)
|
||||||
|
parser_add.add_argument('password', type=str)
|
||||||
|
parser_add.set_defaults(func=add_user)
|
||||||
|
|
||||||
|
parser_validate_login = subparsers.add_parser('validate')
|
||||||
|
parser_validate_login.add_argument('username', type=str)
|
||||||
|
parser_validate_login.add_argument('password', type=str)
|
||||||
|
parser_validate_login.set_defaults(func=validate_login)
|
||||||
|
|
||||||
|
parser_change_pw = subparsers.add_parser('change_password')
|
||||||
|
parser_change_pw.add_argument('username', type=str)
|
||||||
|
parser_change_pw.add_argument('new_password', type=str)
|
||||||
|
parser_change_pw.set_defaults(func=change_password)
|
||||||
|
|
||||||
|
args = parser.parse_args(args)
|
||||||
|
path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA)
|
||||||
|
args.func(hass_auth.load_data(path), args)
|
||||||
|
|
||||||
|
|
||||||
|
def list_users(data, args):
|
||||||
|
"""List the users."""
|
||||||
|
count = 0
|
||||||
|
for user in data.users:
|
||||||
|
count += 1
|
||||||
|
print(user['username'])
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Total users:", count)
|
||||||
|
|
||||||
|
|
||||||
|
def add_user(data, args):
|
||||||
|
"""Create a user."""
|
||||||
|
data.add_user(args.username, args.password)
|
||||||
|
data.save()
|
||||||
|
print("User created")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_login(data, args):
|
||||||
|
"""Validate a login."""
|
||||||
|
try:
|
||||||
|
data.validate_login(args.username, args.password)
|
||||||
|
print("Auth valid")
|
||||||
|
except hass_auth.InvalidAuth:
|
||||||
|
print("Auth invalid")
|
||||||
|
|
||||||
|
|
||||||
|
def change_password(data, args):
|
||||||
|
"""Change password."""
|
||||||
|
try:
|
||||||
|
data.change_password(args.username, args.new_password)
|
||||||
|
data.save()
|
||||||
|
print("Password changed")
|
||||||
|
except hass_auth.InvalidUser:
|
||||||
|
print("User not found")
|
124
tests/auth_providers/test_homeassistant.py
Normal file
124
tests/auth_providers/test_homeassistant.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
"""Test the Home Assistant local auth provider."""
|
||||||
|
from unittest.mock import patch, mock_open
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_PATH = '/bla/users.json'
|
||||||
|
JSON__OPEN_PATH = 'homeassistant.util.json.open'
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_empty_config_file_not_found():
|
||||||
|
"""Test that we initialize an empty config."""
|
||||||
|
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
|
||||||
|
data = hass_auth.load_data(MOCK_PATH)
|
||||||
|
|
||||||
|
assert data is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_adding_user():
|
||||||
|
"""Test adding a user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_adding_user_duplicate_username():
|
||||||
|
"""Test adding a user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
with pytest.raises(hass_auth.InvalidUser):
|
||||||
|
data.add_user('test-user', 'other-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_validating_password_invalid_user():
|
||||||
|
"""Test validating an invalid user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
|
||||||
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
data.validate_login('non-existing', 'pw')
|
||||||
|
|
||||||
|
|
||||||
|
def test_validating_password_invalid_password():
|
||||||
|
"""Test validating an invalid user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
data.validate_login('test-user', 'invalid-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_changing_password():
|
||||||
|
"""Test adding a user."""
|
||||||
|
user = 'test-user'
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user(user, 'test-pass')
|
||||||
|
data.change_password(user, 'new-pass')
|
||||||
|
|
||||||
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
data.validate_login(user, 'test-pass')
|
||||||
|
|
||||||
|
data.validate_login(user, 'new-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_changing_password_raises_invalid_user():
|
||||||
|
"""Test that we initialize an empty config."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
|
||||||
|
with pytest.raises(hass_auth.InvalidUser):
|
||||||
|
data.change_password('non-existing', 'pw')
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_flow_validates(hass):
|
||||||
|
"""Test login flow."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
|
provider = hass_auth.HassAuthProvider(hass, None, {})
|
||||||
|
flow = hass_auth.LoginFlow(provider)
|
||||||
|
result = await flow.async_step_init()
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
|
||||||
|
with patch.object(provider, '_auth_data', return_value=data):
|
||||||
|
result = await flow.async_step_init({
|
||||||
|
'username': 'incorrect-user',
|
||||||
|
'password': 'test-pass',
|
||||||
|
})
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
|
result = await flow.async_step_init({
|
||||||
|
'username': 'test-user',
|
||||||
|
'password': 'incorrect-pass',
|
||||||
|
})
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
|
result = await flow.async_step_init({
|
||||||
|
'username': 'test-user',
|
||||||
|
'password': 'test-pass',
|
||||||
|
})
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
|
||||||
|
|
||||||
|
async def test_saving_loading(hass):
|
||||||
|
"""Test saving and loading JSON."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
data.add_user('second-user', 'second-pass')
|
||||||
|
|
||||||
|
with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write:
|
||||||
|
await hass.async_add_job(data.save)
|
||||||
|
|
||||||
|
# Mock open calls are: open file, context enter, write, context leave
|
||||||
|
written = mock_write.mock_calls[2][1][0]
|
||||||
|
|
||||||
|
with patch('os.path.isfile', return_value=True), \
|
||||||
|
patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True):
|
||||||
|
await hass.async_add_job(hass_auth.load_data, MOCK_PATH)
|
||||||
|
|
||||||
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
data.validate_login('second-user', 'second-pass')
|
|
@ -19,7 +19,7 @@ def store():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def provider(store):
|
def provider(store):
|
||||||
"""Mock provider."""
|
"""Mock provider."""
|
||||||
return insecure_example.ExampleAuthProvider(store, {
|
return insecure_example.ExampleAuthProvider(None, store, {
|
||||||
'type': 'insecure_example',
|
'type': 'insecure_example',
|
||||||
'users': [
|
'users': [
|
||||||
{
|
{
|
||||||
|
@ -64,20 +64,16 @@ async def test_match_existing_credentials(store, provider):
|
||||||
|
|
||||||
async def test_verify_username(provider):
|
async def test_verify_username(provider):
|
||||||
"""Test we raise if incorrect user specified."""
|
"""Test we raise if incorrect user specified."""
|
||||||
with pytest.raises(auth.InvalidUser):
|
with pytest.raises(insecure_example.InvalidAuthError):
|
||||||
await provider.async_get_or_create_credentials({
|
await provider.async_validate_login(
|
||||||
'username': 'non-existing-user',
|
'non-existing-user', 'password-test')
|
||||||
'password': 'password-test',
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_verify_password(provider):
|
async def test_verify_password(provider):
|
||||||
"""Test we raise if incorrect user specified."""
|
"""Test we raise if incorrect user specified."""
|
||||||
with pytest.raises(auth.InvalidPassword):
|
with pytest.raises(insecure_example.InvalidAuthError):
|
||||||
await provider.async_get_or_create_credentials({
|
await provider.async_validate_login(
|
||||||
'username': 'user-test',
|
'user-test', 'incorrect-password')
|
||||||
'password': 'incorrect-password',
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_utf_8_username_password(provider):
|
async def test_utf_8_username_password(provider):
|
||||||
|
|
100
tests/scripts/test_auth.py
Normal file
100
tests/scripts/test_auth.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
"""Test the auth script to manage local users."""
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.scripts import auth as script_auth
|
||||||
|
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
MOCK_PATH = '/bla/users.json'
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_user(capsys):
|
||||||
|
"""Test we can list users."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
data.add_user('second-user', 'second-pass')
|
||||||
|
|
||||||
|
script_auth.list_users(data, None)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == '\n'.join([
|
||||||
|
'test-user',
|
||||||
|
'second-user',
|
||||||
|
'',
|
||||||
|
'Total users: 2',
|
||||||
|
''
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_user(capsys):
|
||||||
|
"""Test we can add a user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
|
||||||
|
with patch.object(data, 'save') as mock_save:
|
||||||
|
script_auth.add_user(
|
||||||
|
data, Mock(username='paulus', password='test-pass'))
|
||||||
|
|
||||||
|
assert len(mock_save.mock_calls) == 1
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'User created\n'
|
||||||
|
|
||||||
|
assert len(data.users) == 1
|
||||||
|
data.validate_login('paulus', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_login(capsys):
|
||||||
|
"""Test we can validate a user login."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
|
script_auth.validate_login(
|
||||||
|
data, Mock(username='test-user', password='test-pass'))
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'Auth valid\n'
|
||||||
|
|
||||||
|
script_auth.validate_login(
|
||||||
|
data, Mock(username='test-user', password='invalid-pass'))
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'Auth invalid\n'
|
||||||
|
|
||||||
|
script_auth.validate_login(
|
||||||
|
data, Mock(username='invalid-user', password='test-pass'))
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'Auth invalid\n'
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_password(capsys):
|
||||||
|
"""Test we can change a password."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
|
with patch.object(data, 'save') as mock_save:
|
||||||
|
script_auth.change_password(
|
||||||
|
data, Mock(username='test-user', new_password='new-pass'))
|
||||||
|
|
||||||
|
assert len(mock_save.mock_calls) == 1
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'Password changed\n'
|
||||||
|
data.validate_login('test-user', 'new-pass')
|
||||||
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_password_invalid_user(capsys):
|
||||||
|
"""Test changing password of non-existing user."""
|
||||||
|
data = hass_auth.Data(MOCK_PATH, None)
|
||||||
|
data.add_user('test-user', 'test-pass')
|
||||||
|
|
||||||
|
with patch.object(data, 'save') as mock_save:
|
||||||
|
script_auth.change_password(
|
||||||
|
data, Mock(username='invalid-user', new_password='new-pass'))
|
||||||
|
|
||||||
|
assert len(mock_save.mock_calls) == 0
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == 'User not found\n'
|
||||||
|
data.validate_login('test-user', 'test-pass')
|
||||||
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
data.validate_login('invalid-user', 'new-pass')
|
Loading…
Add table
Reference in a new issue