User management (#15420)
* User management * Lint * Fix dict * Reuse data instance * OrderedDict all the way
This commit is contained in:
parent
84858f5c19
commit
70fe463ef0
21 changed files with 982 additions and 116 deletions
|
@ -51,7 +51,7 @@ class AuthManager:
|
|||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow,
|
||||
self._async_finish_login_flow)
|
||||
self._access_tokens = {}
|
||||
self._access_tokens = OrderedDict()
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
|
@ -71,9 +71,13 @@ class AuthManager:
|
|||
return False
|
||||
|
||||
@property
|
||||
def async_auth_providers(self):
|
||||
def auth_providers(self):
|
||||
"""Return a list of available auth providers."""
|
||||
return self._providers.values()
|
||||
return list(self._providers.values())
|
||||
|
||||
async def async_get_users(self):
|
||||
"""Retrieve all users."""
|
||||
return await self._store.async_get_users()
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
|
@ -87,6 +91,13 @@ class AuthManager:
|
|||
is_active=True,
|
||||
)
|
||||
|
||||
async def async_create_user(self, name):
|
||||
"""Create a user."""
|
||||
return await self._store.async_create_user(
|
||||
name=name,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
async def async_get_or_create_user(self, credentials):
|
||||
"""Get or create a user."""
|
||||
if not credentials.is_new:
|
||||
|
@ -98,6 +109,10 @@ class AuthManager:
|
|||
raise ValueError('Unable to find the user.')
|
||||
|
||||
auth_provider = self._async_get_auth_provider(credentials)
|
||||
|
||||
if auth_provider is None:
|
||||
raise RuntimeError('Credential with unknown provider encountered')
|
||||
|
||||
info = await auth_provider.async_user_meta_for_credentials(
|
||||
credentials)
|
||||
|
||||
|
@ -122,8 +137,26 @@ class AuthManager:
|
|||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
tasks = [
|
||||
self.async_remove_credentials(credentials)
|
||||
for credentials in user.credentials
|
||||
]
|
||||
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
await self._store.async_remove_user(user)
|
||||
|
||||
async def async_remove_credentials(self, credentials):
|
||||
"""Remove credentials."""
|
||||
provider = self._async_get_auth_provider(credentials)
|
||||
|
||||
if (provider is not None and
|
||||
hasattr(provider, 'async_will_remove_credentials')):
|
||||
await provider.async_will_remove_credentials(credentials)
|
||||
|
||||
await self._store.async_remove_credentials(credentials)
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
|
@ -168,10 +201,6 @@ class AuthManager:
|
|||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
if not auth_provider.initialized:
|
||||
auth_provider.initialized = True
|
||||
await auth_provider.async_initialize()
|
||||
|
||||
return await auth_provider.async_credential_flow()
|
||||
|
||||
async def _async_finish_login_flow(self, result):
|
||||
|
@ -188,4 +217,4 @@ class AuthManager:
|
|||
"""Helper to get auth provider from a set of credentials."""
|
||||
auth_provider_key = (credentials.auth_provider_type,
|
||||
credentials.auth_provider_id)
|
||||
return self._providers[auth_provider_key]
|
||||
return self._providers.get(auth_provider_key)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Storage for auth models."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -80,6 +81,22 @@ class AuthStore:
|
|||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_remove_credentials(self, credentials):
|
||||
"""Remove credentials."""
|
||||
for user in self._users.values():
|
||||
found = None
|
||||
|
||||
for index, cred in enumerate(user.credentials):
|
||||
if cred is credentials:
|
||||
found = index
|
||||
break
|
||||
|
||||
if found is not None:
|
||||
user.credentials.pop(found)
|
||||
break
|
||||
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new token for a user."""
|
||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||
|
@ -108,14 +125,14 @@ class AuthStore:
|
|||
if self._users is not None:
|
||||
return
|
||||
|
||||
users = OrderedDict()
|
||||
|
||||
if data is None:
|
||||
self._users = {}
|
||||
self._users = users
|
||||
return
|
||||
|
||||
users = {
|
||||
user_dict['id']: models.User(**user_dict)
|
||||
for user_dict in data['users']
|
||||
}
|
||||
for user_dict in data['users']:
|
||||
users[user_dict['id']] = models.User(**user_dict)
|
||||
|
||||
for cred_dict in data['credentials']:
|
||||
users[cred_dict['user_id']].credentials.append(models.Credentials(
|
||||
|
@ -126,7 +143,7 @@ class AuthStore:
|
|||
data=cred_dict['data'],
|
||||
))
|
||||
|
||||
refresh_tokens = {}
|
||||
refresh_tokens = OrderedDict()
|
||||
|
||||
for rt_dict in data['refresh_tokens']:
|
||||
token = models.RefreshToken(
|
||||
|
|
|
@ -77,8 +77,6 @@ class AuthProvider:
|
|||
|
||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||
|
||||
initialized = False
|
||||
|
||||
def __init__(self, hass, store, config):
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
|
@ -125,12 +123,6 @@ class AuthProvider:
|
|||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_initialize(self):
|
||||
"""Initialize the auth provider.
|
||||
|
||||
Optional.
|
||||
"""
|
||||
|
||||
async def async_credential_flow(self):
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -7,6 +7,8 @@ import hmac
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
|
@ -16,8 +18,17 @@ from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
|||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
def _disallow_id(conf):
|
||||
"""Disallow ID in config."""
|
||||
if CONF_ID in conf:
|
||||
raise vol.Invalid(
|
||||
'ID is not allowed for the homeassistant auth provider.')
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
CONFIG_SCHEMA = vol.All(AUTH_PROVIDER_SCHEMA, _disallow_id)
|
||||
|
||||
|
||||
class InvalidAuth(HomeAssistantError):
|
||||
|
@ -88,8 +99,8 @@ class Data:
|
|||
hashed = base64.b64encode(hashed).decode()
|
||||
return hashed
|
||||
|
||||
def add_user(self, username, password):
|
||||
"""Add a user."""
|
||||
def add_auth(self, username, password):
|
||||
"""Add a new authenticated user/pass."""
|
||||
if any(user['username'] == username for user in self.users):
|
||||
raise InvalidUser
|
||||
|
||||
|
@ -98,8 +109,22 @@ class Data:
|
|||
'password': self.hash_password(password, True),
|
||||
})
|
||||
|
||||
@callback
|
||||
def async_remove_auth(self, username):
|
||||
"""Remove authentication."""
|
||||
index = None
|
||||
for i, user in enumerate(self.users):
|
||||
if user['username'] == username:
|
||||
index = i
|
||||
break
|
||||
|
||||
if index is None:
|
||||
raise InvalidUser
|
||||
|
||||
self.users.pop(index)
|
||||
|
||||
def change_password(self, username, new_password):
|
||||
"""Update the password of a user.
|
||||
"""Update the password.
|
||||
|
||||
Raises InvalidUser if user cannot be found.
|
||||
"""
|
||||
|
@ -121,16 +146,24 @@ class HassAuthProvider(AuthProvider):
|
|||
|
||||
DEFAULT_TITLE = 'Home Assistant Local'
|
||||
|
||||
data = None
|
||||
|
||||
async def async_initialize(self):
|
||||
"""Initialize the auth provider."""
|
||||
self.data = Data(self.hass)
|
||||
await self.data.async_load()
|
||||
|
||||
async def async_credential_flow(self):
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
async def async_validate_login(self, username, password):
|
||||
"""Helper to validate a username and password."""
|
||||
data = Data(self.hass)
|
||||
await data.async_load()
|
||||
if self.data is None:
|
||||
await self.async_initialize()
|
||||
|
||||
await self.hass.async_add_executor_job(
|
||||
data.validate_login, username, password)
|
||||
self.data.validate_login, username, password)
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
|
@ -145,6 +178,24 @@ class HassAuthProvider(AuthProvider):
|
|||
'username': username
|
||||
})
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
"""Get extra info for this credential."""
|
||||
return {
|
||||
'name': credentials.data['username']
|
||||
}
|
||||
|
||||
async def async_will_remove_credentials(self, credentials):
|
||||
"""When credentials get removed, also remove the auth."""
|
||||
if self.data is None:
|
||||
await self.async_initialize()
|
||||
|
||||
try:
|
||||
self.data.async_remove_auth(credentials.data['username'])
|
||||
await self.data.async_save()
|
||||
except InvalidUser:
|
||||
# Can happen if somehow we didn't clean up a credential
|
||||
pass
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
|
|
@ -152,7 +152,7 @@ class AuthProvidersView(HomeAssistantView):
|
|||
'name': provider.name,
|
||||
'id': provider.id,
|
||||
'type': provider.type,
|
||||
} for provider in request.app['hass'].auth.async_auth_providers])
|
||||
} for provider in request.app['hass'].auth.auth_providers])
|
||||
|
||||
|
||||
class LoginFlowIndexView(FlowManagerIndexView):
|
||||
|
|
|
@ -66,8 +66,8 @@ CAMERA_SERVICE_SNAPSHOT = CAMERA_SERVICE_SCHEMA.extend({
|
|||
|
||||
WS_TYPE_CAMERA_THUMBNAIL = 'camera_thumbnail'
|
||||
SCHEMA_WS_CAMERA_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
'type': WS_TYPE_CAMERA_THUMBNAIL,
|
||||
'entity_id': cv.entity_id
|
||||
vol.Required('type'): WS_TYPE_CAMERA_THUMBNAIL,
|
||||
vol.Required('entity_id'): cv.entity_id
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -49,6 +49,10 @@ async def async_setup(hass, config):
|
|||
|
||||
tasks = [setup_panel(panel_name) for panel_name in SECTIONS]
|
||||
|
||||
if hass.auth.active:
|
||||
tasks.append(setup_panel('auth'))
|
||||
tasks.append(setup_panel('auth_provider_homeassistant'))
|
||||
|
||||
for panel_name in ON_DEMAND:
|
||||
if panel_name in hass.config.components:
|
||||
tasks.append(setup_panel(panel_name))
|
||||
|
|
113
homeassistant/components/config/auth.py
Normal file
113
homeassistant/components/config/auth.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
"""Offer API to configure Home Assistant auth."""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api
|
||||
|
||||
|
||||
WS_TYPE_LIST = 'config/auth/list'
|
||||
SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_LIST,
|
||||
})
|
||||
|
||||
WS_TYPE_DELETE = 'config/auth/delete'
|
||||
SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_DELETE,
|
||||
vol.Required('user_id'): str,
|
||||
})
|
||||
|
||||
WS_TYPE_CREATE = 'config/auth/create'
|
||||
SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_CREATE,
|
||||
vol.Required('name'): str,
|
||||
})
|
||||
|
||||
|
||||
async def async_setup(hass):
|
||||
"""Enable the Home Assistant views."""
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_LIST, websocket_list,
|
||||
SCHEMA_WS_LIST
|
||||
)
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_DELETE, websocket_delete,
|
||||
SCHEMA_WS_DELETE
|
||||
)
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_CREATE, websocket_create,
|
||||
SCHEMA_WS_CREATE
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
def websocket_list(hass, connection, msg):
|
||||
"""Return a list of users."""
|
||||
async def send_users():
|
||||
"""Send users."""
|
||||
result = [_user_info(u) for u in await hass.auth.async_get_users()]
|
||||
|
||||
connection.send_message_outside(
|
||||
websocket_api.result_message(msg['id'], result))
|
||||
|
||||
hass.async_add_job(send_users())
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
def websocket_delete(hass, connection, msg):
|
||||
"""Delete a user."""
|
||||
async def delete_user():
|
||||
"""Delete user."""
|
||||
if msg['user_id'] == connection.request.get('hass_user').id:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'no_delete_self',
|
||||
'Unable to delete your own account'))
|
||||
return
|
||||
|
||||
user = await hass.auth.async_get_user(msg['user_id'])
|
||||
|
||||
if not user:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'not_found', 'User not found'))
|
||||
return
|
||||
|
||||
await hass.auth.async_remove_user(user)
|
||||
|
||||
connection.send_message_outside(
|
||||
websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(delete_user())
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
def websocket_create(hass, connection, msg):
|
||||
"""Create a user."""
|
||||
async def create_user():
|
||||
"""Create a user."""
|
||||
user = await hass.auth.async_create_user(msg['name'])
|
||||
|
||||
connection.send_message_outside(
|
||||
websocket_api.result_message(msg['id'], {
|
||||
'user': _user_info(user)
|
||||
}))
|
||||
|
||||
hass.async_add_job(create_user())
|
||||
|
||||
|
||||
def _user_info(user):
|
||||
"""Format a user."""
|
||||
return {
|
||||
'id': user.id,
|
||||
'name': user.name,
|
||||
'is_owner': user.is_owner,
|
||||
'is_active': user.is_active,
|
||||
'system_generated': user.system_generated,
|
||||
'credentials': [
|
||||
{
|
||||
'type': c.auth_provider_type,
|
||||
} for c in user.credentials
|
||||
]
|
||||
}
|
120
homeassistant/components/config/auth_provider_homeassistant.py
Normal file
120
homeassistant/components/config/auth_provider_homeassistant.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
"""Offer API to configure the Home Assistant auth provider."""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.providers import homeassistant as auth_ha
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api
|
||||
|
||||
|
||||
WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create'
|
||||
SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_CREATE,
|
||||
vol.Required('user_id'): str,
|
||||
vol.Required('username'): str,
|
||||
vol.Required('password'): str,
|
||||
})
|
||||
|
||||
WS_TYPE_DELETE = 'config/auth_provider/homeassistant/delete'
|
||||
SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_DELETE,
|
||||
vol.Required('username'): str,
|
||||
})
|
||||
|
||||
|
||||
async def async_setup(hass):
|
||||
"""Enable the Home Assistant views."""
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_CREATE, websocket_create,
|
||||
SCHEMA_WS_CREATE
|
||||
)
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_DELETE, websocket_delete,
|
||||
SCHEMA_WS_DELETE
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _get_provider(hass):
|
||||
"""Get homeassistant auth provider."""
|
||||
for prv in hass.auth.auth_providers:
|
||||
if prv.type == 'homeassistant':
|
||||
return prv
|
||||
|
||||
raise RuntimeError('Provider not found')
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
def websocket_create(hass, connection, msg):
|
||||
"""Create credentials and attach to a user."""
|
||||
async def create_creds():
|
||||
"""Create credentials."""
|
||||
provider = _get_provider(hass)
|
||||
await provider.async_initialize()
|
||||
|
||||
user = await hass.auth.async_get_user(msg['user_id'])
|
||||
|
||||
if user is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'not_found', 'User not found'))
|
||||
return
|
||||
|
||||
if user.system_generated:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'system_generated',
|
||||
'Cannot add credentials to a system generated user.'))
|
||||
return
|
||||
|
||||
try:
|
||||
await hass.async_add_executor_job(
|
||||
provider.data.add_auth, msg['username'], msg['password'])
|
||||
except auth_ha.InvalidUser:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'username_exists', 'Username already exists'))
|
||||
return
|
||||
|
||||
credentials = await provider.async_get_or_create_credentials({
|
||||
'username': msg['username']
|
||||
})
|
||||
await hass.auth.async_link_user(user, credentials)
|
||||
|
||||
await provider.data.async_save()
|
||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(create_creds())
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
def websocket_delete(hass, connection, msg):
|
||||
"""Delete username and related credential."""
|
||||
async def delete_creds():
|
||||
"""Delete user credentials."""
|
||||
provider = _get_provider(hass)
|
||||
await provider.async_initialize()
|
||||
|
||||
credentials = await provider.async_get_or_create_credentials({
|
||||
'username': msg['username']
|
||||
})
|
||||
|
||||
# if not new, an existing credential exists.
|
||||
# Removing the credential will also remove the auth.
|
||||
if not credentials.is_new:
|
||||
await hass.auth.async_remove_credentials(credentials)
|
||||
|
||||
connection.to_write.put_nowait(
|
||||
websocket_api.result_message(msg['id']))
|
||||
return
|
||||
|
||||
try:
|
||||
provider.data.async_remove_auth(msg['username'])
|
||||
await provider.data.async_save()
|
||||
except auth_ha.InvalidUser:
|
||||
connection.to_write.put_nowait(websocket_api.error_message(
|
||||
msg['id'], 'auth_not_found', 'Given username was not found.'))
|
||||
return
|
||||
|
||||
connection.to_write.put_nowait(
|
||||
websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(delete_creds())
|
|
@ -106,6 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
|
|||
if access_token is None:
|
||||
return False
|
||||
|
||||
user = access_token.refresh_token.user
|
||||
|
||||
if not user.is_active:
|
||||
return False
|
||||
|
||||
request['hass_user'] = access_token.refresh_token.user
|
||||
return True
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ https://home-assistant.io/developers/websocket_api/
|
|||
import asyncio
|
||||
from concurrent import futures
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
@ -196,6 +196,23 @@ def async_register_command(hass, command, handler, schema):
|
|||
handlers[command] = (handler, schema)
|
||||
|
||||
|
||||
def require_owner(func):
|
||||
"""Websocket decorator to require user to be an owner."""
|
||||
@wraps(func)
|
||||
def with_owner(hass, connection, msg):
|
||||
"""Check owner and call function."""
|
||||
user = connection.request.get('hass_user')
|
||||
|
||||
if user is None or not user.is_owner:
|
||||
connection.to_write.put_nowait(error_message(
|
||||
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
||||
return
|
||||
|
||||
func(hass, connection, msg)
|
||||
|
||||
return with_owner
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Initialize the websocket API."""
|
||||
hass.http.register_view(WebsocketAPIView)
|
||||
|
@ -325,6 +342,8 @@ class ActiveConnection:
|
|||
token = self.hass.auth.async_get_access_token(
|
||||
msg['access_token'])
|
||||
authenticated = token is not None
|
||||
if authenticated:
|
||||
request['hass_user'] = token.refresh_token.user
|
||||
|
||||
elif ((not self.hass.auth.active or
|
||||
self.hass.auth.support_legacy) and
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""Script to manage users for the Home Assistant auth provider."""
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from homeassistant.auth import auth_manager_from_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config import get_default_config_dir
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
|
@ -42,16 +44,28 @@ def run(args):
|
|||
args = parser.parse_args(args)
|
||||
loop = asyncio.get_event_loop()
|
||||
hass = HomeAssistant(loop=loop)
|
||||
loop.run_until_complete(run_command(hass, args))
|
||||
|
||||
# Triggers save on used storage helpers with delay (core auth)
|
||||
logging.getLogger('homeassistant.core').setLevel(logging.WARNING)
|
||||
loop.run_until_complete(hass.async_stop())
|
||||
|
||||
|
||||
async def run_command(hass, args):
|
||||
"""Run the command."""
|
||||
hass.config.config_dir = os.path.join(os.getcwd(), args.config)
|
||||
data = hass_auth.Data(hass)
|
||||
loop.run_until_complete(data.async_load())
|
||||
loop.run_until_complete(args.func(data, args))
|
||||
hass.auth = await auth_manager_from_config(hass, [{
|
||||
'type': 'homeassistant',
|
||||
}])
|
||||
provider = hass.auth.auth_providers[0]
|
||||
await provider.async_initialize()
|
||||
await args.func(hass, provider, args)
|
||||
|
||||
|
||||
async def list_users(data, args):
|
||||
async def list_users(hass, provider, args):
|
||||
"""List the users."""
|
||||
count = 0
|
||||
for user in data.users:
|
||||
for user in provider.data.users:
|
||||
count += 1
|
||||
print(user['username'])
|
||||
|
||||
|
@ -59,27 +73,40 @@ async def list_users(data, args):
|
|||
print("Total users:", count)
|
||||
|
||||
|
||||
async def add_user(data, args):
|
||||
async def add_user(hass, provider, args):
|
||||
"""Create a user."""
|
||||
data.add_user(args.username, args.password)
|
||||
await data.async_save()
|
||||
try:
|
||||
provider.data.add_auth(args.username, args.password)
|
||||
except hass_auth.InvalidUser:
|
||||
print("Username already exists!")
|
||||
return
|
||||
|
||||
credentials = await provider.async_get_or_create_credentials({
|
||||
'username': args.username
|
||||
})
|
||||
|
||||
user = await hass.auth.async_create_user(args.username)
|
||||
await hass.auth.async_link_user(user, credentials)
|
||||
|
||||
# Save username/password
|
||||
await provider.data.async_save()
|
||||
print("User created")
|
||||
|
||||
|
||||
async def validate_login(data, args):
|
||||
async def validate_login(hass, provider, args):
|
||||
"""Validate a login."""
|
||||
try:
|
||||
data.validate_login(args.username, args.password)
|
||||
provider.data.validate_login(args.username, args.password)
|
||||
print("Auth valid")
|
||||
except hass_auth.InvalidAuth:
|
||||
print("Auth invalid")
|
||||
|
||||
|
||||
async def change_password(data, args):
|
||||
async def change_password(hass, provider, args):
|
||||
"""Change password."""
|
||||
try:
|
||||
data.change_password(args.username, args.new_password)
|
||||
await data.async_save()
|
||||
provider.data.change_password(args.username, args.new_password)
|
||||
await provider.data.async_save()
|
||||
print("Password changed")
|
||||
except hass_auth.InvalidUser:
|
||||
print("User not found")
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
"""Test the Home Assistant local auth provider."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
from homeassistant.auth.providers import (
|
||||
auth_provider_from_config, homeassistant as hass_auth)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -15,15 +18,15 @@ def data(hass):
|
|||
|
||||
async def test_adding_user(data, hass):
|
||||
"""Test adding a user."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
|
||||
|
||||
async def test_adding_user_duplicate_username(data, hass):
|
||||
"""Test adding a user."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
with pytest.raises(hass_auth.InvalidUser):
|
||||
data.add_user('test-user', 'other-pass')
|
||||
data.add_auth('test-user', 'other-pass')
|
||||
|
||||
|
||||
async def test_validating_password_invalid_user(data, hass):
|
||||
|
@ -34,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
|
|||
|
||||
async def test_validating_password_invalid_password(data, hass):
|
||||
"""Test validating an invalid user."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
data.validate_login('test-user', 'invalid-pass')
|
||||
|
@ -43,7 +46,7 @@ async def test_validating_password_invalid_password(data, hass):
|
|||
async def test_changing_password(data, hass):
|
||||
"""Test adding a user."""
|
||||
user = 'test-user'
|
||||
data.add_user(user, 'test-pass')
|
||||
data.add_auth(user, 'test-pass')
|
||||
data.change_password(user, 'new-pass')
|
||||
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
|
@ -60,7 +63,7 @@ async def test_changing_password_raises_invalid_user(data, hass):
|
|||
|
||||
async def test_login_flow_validates(data, hass):
|
||||
"""Test login flow."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
await data.async_save()
|
||||
|
||||
provider = hass_auth.HassAuthProvider(hass, None, {})
|
||||
|
@ -91,11 +94,21 @@ async def test_login_flow_validates(data, hass):
|
|||
|
||||
async def test_saving_loading(data, hass):
|
||||
"""Test saving and loading JSON."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_user('second-user', 'second-pass')
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
data.add_auth('second-user', 'second-pass')
|
||||
await data.async_save()
|
||||
|
||||
data = hass_auth.Data(hass)
|
||||
await data.async_load()
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
data.validate_login('second-user', 'second-pass')
|
||||
|
||||
|
||||
async def test_not_allow_set_id():
|
||||
"""Test we are not allowed to set an ID in config."""
|
||||
hass = Mock()
|
||||
provider = await auth_provider_from_config(hass, None, {
|
||||
'type': 'homeassistant',
|
||||
'id': 'invalid',
|
||||
})
|
||||
assert provider is None
|
||||
|
|
|
@ -46,7 +46,7 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
|
|||
'name': provider.name,
|
||||
'id': provider.id,
|
||||
'type': provider.type,
|
||||
} for provider in manager.async_auth_providers]
|
||||
} for provider in manager.auth_providers]
|
||||
assert providers == [{
|
||||
'name': 'Test Name',
|
||||
'type': 'insecure_example',
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Test the helper method for writing tests."""
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
import functools as ft
|
||||
import json
|
||||
|
@ -12,7 +13,8 @@ import threading
|
|||
from contextlib import contextmanager
|
||||
|
||||
from homeassistant import auth, core as ha, data_entry_flow, config_entries
|
||||
from homeassistant.auth import models as auth_models, auth_store
|
||||
from homeassistant.auth import (
|
||||
models as auth_models, auth_store, providers as auth_providers)
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.helpers import (
|
||||
|
@ -312,11 +314,12 @@ def mock_registry(hass, mock_entries=None):
|
|||
class MockUser(auth_models.User):
|
||||
"""Mock a user in Home Assistant."""
|
||||
|
||||
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
||||
name='Mock User'):
|
||||
def __init__(self, id='mock-id', is_owner=False, is_active=True,
|
||||
name='Mock User', system_generated=False):
|
||||
"""Initialize mock user."""
|
||||
super().__init__(
|
||||
id=id, is_owner=is_owner, is_active=is_active, name=name)
|
||||
id=id, is_owner=is_owner, is_active=is_active, name=name,
|
||||
system_generated=system_generated)
|
||||
|
||||
def add_to_hass(self, hass):
|
||||
"""Test helper to add entry to hass."""
|
||||
|
@ -329,12 +332,27 @@ class MockUser(auth_models.User):
|
|||
return self
|
||||
|
||||
|
||||
async def register_auth_provider(hass, config):
|
||||
"""Helper to register an auth provider."""
|
||||
provider = await auth_providers.auth_provider_from_config(
|
||||
hass, hass.auth._store, config)
|
||||
assert provider is not None, 'Invalid config specified'
|
||||
key = (provider.type, provider.id)
|
||||
providers = hass.auth._providers
|
||||
|
||||
if key in providers:
|
||||
raise ValueError('Provider already registered')
|
||||
|
||||
providers[key] = provider
|
||||
return provider
|
||||
|
||||
|
||||
@ha.callback
|
||||
def ensure_auth_manager_loaded(auth_mgr):
|
||||
"""Ensure an auth manager is considered loaded."""
|
||||
store = auth_mgr._store
|
||||
if store._users is None:
|
||||
store._users = {}
|
||||
store._users = OrderedDict()
|
||||
|
||||
|
||||
class MockModule(object):
|
||||
|
@ -731,7 +749,13 @@ def mock_storage(data=None):
|
|||
if store.key not in data:
|
||||
return None
|
||||
|
||||
store._data = data.get(store.key)
|
||||
mock_data = data.get(store.key)
|
||||
|
||||
if 'data' not in mock_data or 'version' not in mock_data:
|
||||
_LOGGER.error('Mock data needs "version" and "data"')
|
||||
raise ValueError('Mock data needs "version" and "data"')
|
||||
|
||||
store._data = mock_data
|
||||
|
||||
# Route through original load so that we trigger migration
|
||||
loaded = await orig_load(store)
|
||||
|
|
211
tests/components/config/test_auth.py
Normal file
211
tests/components/config/test_auth.py
Normal file
|
@ -0,0 +1,211 @@
|
|||
"""Test config entries API."""
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth import models as auth_models
|
||||
from homeassistant.components.config import auth as auth_config
|
||||
|
||||
from tests.common import MockUser, CLIENT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def auth_active(hass):
|
||||
"""Mock that auth is active."""
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
PropertyMock(return_value=True)):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_config(hass, aiohttp_client):
|
||||
"""Fixture that sets up the auth provider homeassistant module."""
|
||||
hass.loop.run_until_complete(auth_config.async_setup(hass))
|
||||
|
||||
|
||||
async def test_list_requires_owner(hass, hass_ws_client, hass_access_token):
|
||||
"""Test get users requires auth."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_LIST,
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
||||
|
||||
|
||||
async def test_list(hass, hass_ws_client):
|
||||
"""Test get users."""
|
||||
owner = MockUser(
|
||||
id='abc',
|
||||
name='Test Owner',
|
||||
is_owner=True,
|
||||
).add_to_hass(hass)
|
||||
|
||||
owner.credentials.append(auth_models.Credentials(
|
||||
auth_provider_type='homeassistant',
|
||||
auth_provider_id=None,
|
||||
data={},
|
||||
))
|
||||
|
||||
system = MockUser(
|
||||
id='efg',
|
||||
name='Test Hass.io',
|
||||
system_generated=True
|
||||
).add_to_hass(hass)
|
||||
|
||||
inactive = MockUser(
|
||||
id='hij',
|
||||
name='Inactive User',
|
||||
is_active=False,
|
||||
).add_to_hass(hass)
|
||||
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
owner, CLIENT_ID)
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
client = await hass_ws_client(hass, access_token)
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_LIST,
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
data = result['result']
|
||||
assert len(data) == 3
|
||||
assert data[0] == {
|
||||
'id': owner.id,
|
||||
'name': 'Test Owner',
|
||||
'is_owner': True,
|
||||
'is_active': True,
|
||||
'system_generated': False,
|
||||
'credentials': [{'type': 'homeassistant'}]
|
||||
}
|
||||
assert data[1] == {
|
||||
'id': system.id,
|
||||
'name': 'Test Hass.io',
|
||||
'is_owner': False,
|
||||
'is_active': True,
|
||||
'system_generated': True,
|
||||
'credentials': [],
|
||||
}
|
||||
assert data[2] == {
|
||||
'id': inactive.id,
|
||||
'name': 'Inactive User',
|
||||
'is_owner': False,
|
||||
'is_active': False,
|
||||
'system_generated': False,
|
||||
'credentials': [],
|
||||
}
|
||||
|
||||
|
||||
async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
|
||||
"""Test delete command requires an owner."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_DELETE,
|
||||
'user_id': 'abcd',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
||||
|
||||
|
||||
async def test_delete_unable_self_account(hass, hass_ws_client,
|
||||
hass_access_token):
|
||||
"""Test we cannot delete our own account."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_DELETE,
|
||||
'user_id': hass_access_token.refresh_token.user.id,
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
||||
|
||||
|
||||
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
||||
"""Test we cannot delete an unknown user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_DELETE,
|
||||
'user_id': 'abcd',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'not_found'
|
||||
|
||||
|
||||
async def test_delete(hass, hass_ws_client, hass_access_token):
|
||||
"""Test delete command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
test_user = MockUser(
|
||||
id='efg',
|
||||
).add_to_hass(hass)
|
||||
|
||||
assert len(await hass.auth.async_get_users()) == 2
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_DELETE,
|
||||
'user_id': test_user.id,
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
assert len(await hass.auth.async_get_users()) == 1
|
||||
|
||||
|
||||
async def test_create(hass, hass_ws_client, hass_access_token):
|
||||
"""Test create command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
assert len(await hass.auth.async_get_users()) == 1
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_CREATE,
|
||||
'name': 'Paulus',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
assert len(await hass.auth.async_get_users()) == 2
|
||||
data_user = result['result']['user']
|
||||
user = await hass.auth.async_get_user(data_user['id'])
|
||||
assert user is not None
|
||||
assert user.name == data_user['name']
|
||||
assert user.is_active
|
||||
assert not user.is_owner
|
||||
assert not user.system_generated
|
||||
|
||||
|
||||
async def test_create_requires_owner(hass, hass_ws_client, hass_access_token):
|
||||
"""Test create command requires an owner."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_CREATE,
|
||||
'name': 'YO',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
229
tests/components/config/test_auth_provider_homeassistant.py
Normal file
229
tests/components/config/test_auth_provider_homeassistant.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
"""Test config entries API."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.providers import homeassistant as prov_ha
|
||||
from homeassistant.components.config import (
|
||||
auth_provider_homeassistant as auth_ha)
|
||||
|
||||
from tests.common import MockUser, register_auth_provider
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_config(hass, aiohttp_client):
|
||||
"""Fixture that sets up the auth provider homeassistant module."""
|
||||
hass.loop.run_until_complete(register_auth_provider(hass, {
|
||||
'type': 'homeassistant'
|
||||
}))
|
||||
hass.loop.run_until_complete(auth_ha.async_setup(hass))
|
||||
|
||||
|
||||
async def test_create_auth_system_generated_user(hass, hass_access_token,
|
||||
hass_ws_client):
|
||||
"""Test we can't add auth to system generated users."""
|
||||
system_user = MockUser(system_generated=True).add_to_hass(hass)
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_CREATE,
|
||||
'user_id': system_user.id,
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'system_generated'
|
||||
|
||||
|
||||
async def test_create_auth_user_already_credentials():
|
||||
"""Test we can't create auth for user with pre-existing credentials."""
|
||||
# assert False
|
||||
|
||||
|
||||
async def test_create_auth_unknown_user(hass_ws_client, hass,
|
||||
hass_access_token):
|
||||
"""Test create pointing at unknown user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_CREATE,
|
||||
'user_id': 'test-id',
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'not_found'
|
||||
|
||||
|
||||
async def test_create_auth_requires_owner(hass, hass_ws_client,
|
||||
hass_access_token):
|
||||
"""Test create requires owner to call API."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_CREATE,
|
||||
'user_id': 'test-id',
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
||||
|
||||
|
||||
async def test_create_auth(hass, hass_ws_client, hass_access_token,
|
||||
hass_storage):
|
||||
"""Test create auth command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
user = MockUser().add_to_hass(hass)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
assert len(user.credentials) == 0
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_CREATE,
|
||||
'user_id': user.id,
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
assert len(user.credentials) == 1
|
||||
creds = user.credentials[0]
|
||||
assert creds.auth_provider_type == 'homeassistant'
|
||||
assert creds.auth_provider_id is None
|
||||
assert creds.data == {
|
||||
'username': 'test-user'
|
||||
}
|
||||
assert prov_ha.STORAGE_KEY in hass_storage
|
||||
entry = hass_storage[prov_ha.STORAGE_KEY]['data']['users'][0]
|
||||
assert entry['username'] == 'test-user'
|
||||
|
||||
|
||||
async def test_create_auth_duplicate_username(hass, hass_ws_client,
|
||||
hass_access_token, hass_storage):
|
||||
"""Test we can't create auth with a duplicate username."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
user = MockUser().add_to_hass(hass)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||
'version': 1,
|
||||
'data': {
|
||||
'users': [{
|
||||
'username': 'test-user'
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_CREATE,
|
||||
'user_id': user.id,
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'username_exists'
|
||||
|
||||
|
||||
async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
|
||||
hass_access_token):
|
||||
"""Test deleting an auth without being connected to a user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||
'version': 1,
|
||||
'data': {
|
||||
'users': [{
|
||||
'username': 'test-user'
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_DELETE,
|
||||
'username': 'test-user',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0
|
||||
|
||||
|
||||
async def test_delete_removes_credential(hass, hass_ws_client,
|
||||
hass_access_token, hass_storage):
|
||||
"""Test deleting auth that is connected to a user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
user = MockUser().add_to_hass(hass)
|
||||
user.credentials.append(
|
||||
await hass.auth.auth_providers[0].async_get_or_create_credentials({
|
||||
'username': 'test-user'}))
|
||||
|
||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||
'version': 1,
|
||||
'data': {
|
||||
'users': [{
|
||||
'username': 'test-user'
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_DELETE,
|
||||
'username': 'test-user',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result['success'], result
|
||||
assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0
|
||||
|
||||
|
||||
async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
|
||||
"""Test delete requires owner."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_DELETE,
|
||||
'username': 'test-user',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'unauthorized'
|
||||
|
||||
|
||||
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
|
||||
"""Test trying to delete an unknown auth username."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_ha.WS_TYPE_DELETE,
|
||||
'username': 'test-user',
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result['success'], result
|
||||
assert result['error']['code'] == 'auth_not_found'
|
|
@ -2,6 +2,7 @@
|
|||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api
|
||||
|
||||
from tests.common import MockUser, CLIENT_ID
|
||||
|
||||
|
@ -9,13 +10,27 @@ from tests.common import MockUser, CLIENT_ID
|
|||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
async def create_client(hass):
|
||||
async def create_client(hass, access_token=None):
|
||||
"""Create a websocket client."""
|
||||
wapi = hass.components.websocket_api
|
||||
assert await async_setup_component(hass, 'websocket_api')
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
websocket = await client.ws_connect(wapi.URL)
|
||||
auth_resp = await websocket.receive_json()
|
||||
|
||||
if auth_resp['type'] == wapi.TYPE_AUTH_OK:
|
||||
assert access_token is None, \
|
||||
'Access token given but no auth required'
|
||||
return websocket
|
||||
|
||||
assert access_token is not None, 'Access token required for fixture'
|
||||
|
||||
await websocket.send_json({
|
||||
'type': websocket_api.TYPE_AUTH,
|
||||
'access_token': access_token.token
|
||||
})
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
assert auth_ok['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
"""The tests for the Home Assistant HTTP component."""
|
||||
# pylint: disable=protected-access
|
||||
from ipaddress import ip_network
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import BasicAuth, web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
|
||||
from homeassistant.auth.models import AccessToken, RefreshToken
|
||||
from homeassistant.components.http.auth import setup_auth
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.real_ip import setup_real_ip
|
||||
|
@ -16,8 +15,6 @@ from homeassistant.setup import async_setup_component
|
|||
from . import mock_real_ip
|
||||
|
||||
|
||||
ACCESS_TOKEN = 'tk.1234'
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
|
||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||
|
@ -39,33 +36,21 @@ async def mock_handler(request):
|
|||
return web.Response(status=200)
|
||||
|
||||
|
||||
def mock_async_get_access_token(token):
|
||||
"""Return if token is valid."""
|
||||
if token == ACCESS_TOKEN:
|
||||
return Mock(spec=AccessToken,
|
||||
token=ACCESS_TOKEN,
|
||||
refresh_token=Mock(spec=RefreshToken))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
def app(hass):
|
||||
"""Fixture to setup a web.Application."""
|
||||
app = web.Application()
|
||||
mock_auth = Mock(async_get_access_token=mock_async_get_access_token)
|
||||
app['hass'] = Mock(auth=mock_auth)
|
||||
app['hass'] = hass
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, False, [])
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app2():
|
||||
def app2(hass):
|
||||
"""Fixture to setup a web.Application without real_ip middleware."""
|
||||
app = web.Application()
|
||||
mock_auth = Mock(async_get_access_token=mock_async_get_access_token)
|
||||
app['hass'] = Mock(auth=mock_auth)
|
||||
app['hass'] = hass
|
||||
app.router.add_get('/', mock_handler)
|
||||
return app
|
||||
|
||||
|
@ -171,33 +156,35 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
|
|||
|
||||
|
||||
async def test_auth_active_access_with_access_token_in_header(
|
||||
app, aiohttp_client):
|
||||
app, aiohttp_client, hass_access_token):
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token.token
|
||||
setup_auth(app, [], True, api_password=None)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get(
|
||||
'/', headers={'Authorization': 'Bearer {}'.format(ACCESS_TOKEN)})
|
||||
'/', headers={'Authorization': 'Bearer {}'.format(token)})
|
||||
assert req.status == 200
|
||||
|
||||
req = await client.get(
|
||||
'/', headers={'AUTHORIZATION': 'Bearer {}'.format(ACCESS_TOKEN)})
|
||||
'/', headers={'AUTHORIZATION': 'Bearer {}'.format(token)})
|
||||
assert req.status == 200
|
||||
|
||||
req = await client.get(
|
||||
'/', headers={'authorization': 'Bearer {}'.format(ACCESS_TOKEN)})
|
||||
'/', headers={'authorization': 'Bearer {}'.format(token)})
|
||||
assert req.status == 200
|
||||
|
||||
req = await client.get(
|
||||
'/', headers={'Authorization': ACCESS_TOKEN})
|
||||
'/', headers={'Authorization': token})
|
||||
assert req.status == 401
|
||||
|
||||
req = await client.get(
|
||||
'/', headers={'Authorization': 'BEARER {}'.format(ACCESS_TOKEN)})
|
||||
'/', headers={'Authorization': 'BEARER {}'.format(token)})
|
||||
assert req.status == 401
|
||||
|
||||
hass_access_token.refresh_token.user.is_active = False
|
||||
req = await client.get(
|
||||
'/', headers={'Authorization': 'Bearer wrong-pass'})
|
||||
'/', headers={'Authorization': 'Bearer {}'.format(token)})
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ if os.environ.get('UVLOOP') == '1':
|
|||
import uvloop
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
|
||||
|
||||
|
||||
|
|
|
@ -6,21 +6,26 @@ import pytest
|
|||
from homeassistant.scripts import auth as script_auth
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
|
||||
from tests.common import register_auth_provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data(hass):
|
||||
"""Create a loaded data class."""
|
||||
data = hass_auth.Data(hass)
|
||||
hass.loop.run_until_complete(data.async_load())
|
||||
return data
|
||||
def provider(hass):
|
||||
"""Home Assistant auth provider."""
|
||||
provider = hass.loop.run_until_complete(register_auth_provider(hass, {
|
||||
'type': 'homeassistant',
|
||||
}))
|
||||
hass.loop.run_until_complete(provider.async_initialize())
|
||||
return provider
|
||||
|
||||
|
||||
async def test_list_user(data, capsys):
|
||||
async def test_list_user(hass, provider, capsys):
|
||||
"""Test we can list users."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_user('second-user', 'second-pass')
|
||||
data = provider.data
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
data.add_auth('second-user', 'second-pass')
|
||||
|
||||
await script_auth.list_users(data, None)
|
||||
await script_auth.list_users(hass, provider, None)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
|
||||
|
@ -33,10 +38,11 @@ async def test_list_user(data, capsys):
|
|||
])
|
||||
|
||||
|
||||
async def test_add_user(data, capsys, hass_storage):
|
||||
async def test_add_user(hass, provider, capsys, hass_storage):
|
||||
"""Test we can add a user."""
|
||||
data = provider.data
|
||||
await script_auth.add_user(
|
||||
data, Mock(username='paulus', password='test-pass'))
|
||||
hass, provider, Mock(username='paulus', password='test-pass'))
|
||||
|
||||
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||
|
||||
|
@ -47,32 +53,34 @@ async def test_add_user(data, capsys, hass_storage):
|
|||
data.validate_login('paulus', 'test-pass')
|
||||
|
||||
|
||||
async def test_validate_login(data, capsys):
|
||||
async def test_validate_login(hass, provider, capsys):
|
||||
"""Test we can validate a user login."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data = provider.data
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='test-user', password='test-pass'))
|
||||
hass, provider, Mock(username='test-user', password='test-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth valid\n'
|
||||
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='test-user', password='invalid-pass'))
|
||||
hass, provider, Mock(username='test-user', password='invalid-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth invalid\n'
|
||||
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='invalid-user', password='test-pass'))
|
||||
hass, provider, Mock(username='invalid-user', password='test-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth invalid\n'
|
||||
|
||||
|
||||
async def test_change_password(data, capsys, hass_storage):
|
||||
async def test_change_password(hass, provider, capsys, hass_storage):
|
||||
"""Test we can change a password."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data = provider.data
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
|
||||
await script_auth.change_password(
|
||||
data, Mock(username='test-user', new_password='new-pass'))
|
||||
hass, provider, Mock(username='test-user', new_password='new-pass'))
|
||||
|
||||
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||
captured = capsys.readouterr()
|
||||
|
@ -82,12 +90,14 @@ async def test_change_password(data, capsys, hass_storage):
|
|||
data.validate_login('test-user', 'test-pass')
|
||||
|
||||
|
||||
async def test_change_password_invalid_user(data, capsys, hass_storage):
|
||||
async def test_change_password_invalid_user(hass, provider, capsys,
|
||||
hass_storage):
|
||||
"""Test changing password of non-existing user."""
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data = provider.data
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
|
||||
await script_auth.change_password(
|
||||
data, Mock(username='invalid-user', new_password='new-pass'))
|
||||
hass, provider, Mock(username='invalid-user', new_password='new-pass'))
|
||||
|
||||
assert hass_auth.STORAGE_KEY not in hass_storage
|
||||
captured = capsys.readouterr()
|
||||
|
@ -101,11 +111,11 @@ def test_parsing_args(loop):
|
|||
"""Test we parse args correctly."""
|
||||
called = False
|
||||
|
||||
async def mock_func(data, args2):
|
||||
async def mock_func(hass, provider, args2):
|
||||
"""Mock function to be called."""
|
||||
nonlocal called
|
||||
called = True
|
||||
assert data.hass.config.config_dir == '/somewhere/config'
|
||||
assert provider.hass.config.config_dir == '/somewhere/config'
|
||||
assert args2 is args
|
||||
|
||||
args = Mock(config='/somewhere/config', func=mock_func)
|
||||
|
|
Loading…
Add table
Reference in a new issue