Support blocking trusted network from new ip (#44630)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
e4a7692610
commit
38d2cacf7a
21 changed files with 381 additions and 131 deletions
|
@ -24,6 +24,14 @@ _ProviderKey = Tuple[str, Optional[str]]
|
|||
_ProviderDict = Dict[_ProviderKey, AuthProvider]
|
||||
|
||||
|
||||
class InvalidAuthError(Exception):
|
||||
"""Raised when a authentication error occurs."""
|
||||
|
||||
|
||||
class InvalidProvider(Exception):
|
||||
"""Authentication provider not found."""
|
||||
|
||||
|
||||
async def auth_manager_from_config(
|
||||
hass: HomeAssistant,
|
||||
provider_configs: List[Dict[str, Any]],
|
||||
|
@ -96,7 +104,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
return result
|
||||
|
||||
# we got final result
|
||||
if isinstance(result["data"], models.User):
|
||||
if isinstance(result["data"], models.Credentials):
|
||||
result["result"] = result["data"]
|
||||
return result
|
||||
|
||||
|
@ -120,11 +128,12 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
modules = await self.auth_manager.async_get_enabled_mfa(user)
|
||||
|
||||
if modules:
|
||||
flow.credential = credentials
|
||||
flow.user = user
|
||||
flow.available_mfa_modules = modules
|
||||
return await flow.async_step_select_mfa_module()
|
||||
|
||||
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
|
||||
result["result"] = credentials
|
||||
return result
|
||||
|
||||
|
||||
|
@ -156,7 +165,7 @@ class AuthManager:
|
|||
return list(self._mfa_modules.values())
|
||||
|
||||
def get_auth_provider(
|
||||
self, provider_type: str, provider_id: str
|
||||
self, provider_type: str, provider_id: Optional[str]
|
||||
) -> Optional[AuthProvider]:
|
||||
"""Return an auth provider, None if not found."""
|
||||
return self._providers.get((provider_type, provider_id))
|
||||
|
@ -367,6 +376,7 @@ class AuthManager:
|
|||
client_icon: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||
credential: Optional[models.Credentials] = None,
|
||||
) -> models.RefreshToken:
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
|
@ -415,6 +425,7 @@ class AuthManager:
|
|||
client_icon,
|
||||
token_type,
|
||||
access_token_expiration,
|
||||
credential,
|
||||
)
|
||||
|
||||
async def async_get_refresh_token(
|
||||
|
@ -440,6 +451,8 @@ class AuthManager:
|
|||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||
) -> str:
|
||||
"""Create a new access token."""
|
||||
self.async_validate_refresh_token(refresh_token, remote_ip)
|
||||
|
||||
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
|
||||
|
||||
now = dt_util.utcnow()
|
||||
|
@ -453,6 +466,40 @@ class AuthManager:
|
|||
algorithm="HS256",
|
||||
).decode()
|
||||
|
||||
@callback
|
||||
def _async_resolve_provider(
|
||||
self, refresh_token: models.RefreshToken
|
||||
) -> Optional[AuthProvider]:
|
||||
"""Get the auth provider for the given refresh token.
|
||||
|
||||
Raises an exception if the expected provider is no longer available or return
|
||||
None if no provider was expected for this refresh token.
|
||||
"""
|
||||
if refresh_token.credential is None:
|
||||
return None
|
||||
|
||||
provider = self.get_auth_provider(
|
||||
refresh_token.credential.auth_provider_type,
|
||||
refresh_token.credential.auth_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
raise InvalidProvider(
|
||||
f"Auth provider {refresh_token.credential.auth_provider_type}, {refresh_token.credential.auth_provider_id} not available"
|
||||
)
|
||||
return provider
|
||||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||
) -> None:
|
||||
"""Validate that a refresh token is usable.
|
||||
|
||||
Will raise InvalidAuthError on errors.
|
||||
"""
|
||||
provider = self._async_resolve_provider(refresh_token)
|
||||
if provider:
|
||||
provider.async_validate_refresh_token(refresh_token, remote_ip)
|
||||
|
||||
async def async_validate_access_token(
|
||||
self, token: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
|
|
|
@ -208,6 +208,7 @@ class AuthStore:
|
|||
client_icon: Optional[str] = None,
|
||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||
credential: Optional[models.Credentials] = None,
|
||||
) -> models.RefreshToken:
|
||||
"""Create a new token for a user."""
|
||||
kwargs: Dict[str, Any] = {
|
||||
|
@ -215,6 +216,7 @@ class AuthStore:
|
|||
"client_id": client_id,
|
||||
"token_type": token_type,
|
||||
"access_token_expiration": access_token_expiration,
|
||||
"credential": credential,
|
||||
}
|
||||
if client_name:
|
||||
kwargs["client_name"] = client_name
|
||||
|
@ -309,6 +311,7 @@ class AuthStore:
|
|||
|
||||
users: Dict[str, models.User] = OrderedDict()
|
||||
groups: Dict[str, models.Group] = OrderedDict()
|
||||
credentials: Dict[str, models.Credentials] = OrderedDict()
|
||||
|
||||
# Soft-migrating data as we load. We are going to make sure we have a
|
||||
# read only group and an admin group. There are two states that we can
|
||||
|
@ -415,15 +418,15 @@ class AuthStore:
|
|||
)
|
||||
|
||||
for cred_dict in data["credentials"]:
|
||||
users[cred_dict["user_id"]].credentials.append(
|
||||
models.Credentials(
|
||||
id=cred_dict["id"],
|
||||
is_new=False,
|
||||
auth_provider_type=cred_dict["auth_provider_type"],
|
||||
auth_provider_id=cred_dict["auth_provider_id"],
|
||||
data=cred_dict["data"],
|
||||
)
|
||||
credential = models.Credentials(
|
||||
id=cred_dict["id"],
|
||||
is_new=False,
|
||||
auth_provider_type=cred_dict["auth_provider_type"],
|
||||
auth_provider_id=cred_dict["auth_provider_id"],
|
||||
data=cred_dict["data"],
|
||||
)
|
||||
credentials[cred_dict["id"]] = credential
|
||||
users[cred_dict["user_id"]].credentials.append(credential)
|
||||
|
||||
for rt_dict in data["refresh_tokens"]:
|
||||
# Filter out the old keys that don't have jwt_key (pre-0.76)
|
||||
|
@ -469,6 +472,8 @@ class AuthStore:
|
|||
jwt_key=rt_dict["jwt_key"],
|
||||
last_used_at=last_used_at,
|
||||
last_used_ip=rt_dict.get("last_used_ip"),
|
||||
credential=credentials.get(rt_dict.get("credential_id")),
|
||||
version=rt_dict.get("version"),
|
||||
)
|
||||
users[rt_dict["user_id"]].refresh_tokens[token.id] = token
|
||||
|
||||
|
@ -542,6 +547,10 @@ class AuthStore:
|
|||
if refresh_token.last_used_at
|
||||
else None,
|
||||
"last_used_ip": refresh_token.last_used_ip,
|
||||
"credential_id": refresh_token.credential.id
|
||||
if refresh_token.credential
|
||||
else None,
|
||||
"version": refresh_token.version,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
|
|
|
@ -6,6 +6,7 @@ import uuid
|
|||
|
||||
import attr
|
||||
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import permissions as perm_mdl
|
||||
|
@ -106,6 +107,10 @@ class RefreshToken:
|
|||
last_used_at: Optional[datetime] = attr.ib(default=None)
|
||||
last_used_ip: Optional[str] = attr.ib(default=None)
|
||||
|
||||
credential: Optional["Credentials"] = attr.ib(default=None)
|
||||
|
||||
version: Optional[str] = attr.ib(default=__version__)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Credentials:
|
||||
|
|
|
@ -16,7 +16,7 @@ from homeassistant.util.decorator import Registry
|
|||
|
||||
from ..auth_store import AuthStore
|
||||
from ..const import MFA_SESSION_EXPIRATION
|
||||
from ..models import Credentials, User, UserMeta
|
||||
from ..models import Credentials, RefreshToken, User, UserMeta
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = "auth_prov_reqs_processed"
|
||||
|
@ -117,6 +117,16 @@ class AuthProvider:
|
|||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||
) -> None:
|
||||
"""Verify a refresh token is still valid.
|
||||
|
||||
Optional hook for an auth provider to verify validity of a refresh token.
|
||||
Should raise InvalidAuthError on errors.
|
||||
"""
|
||||
|
||||
|
||||
async def auth_provider_from_config(
|
||||
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
||||
|
@ -182,6 +192,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
self.created_at = dt_util.utcnow()
|
||||
self.invalid_mfa_times = 0
|
||||
self.user: Optional[User] = None
|
||||
self.credential: Optional[Credentials] = None
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
|
@ -222,6 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Handle the step of mfa validation."""
|
||||
assert self.credential
|
||||
assert self.user
|
||||
|
||||
errors = {}
|
||||
|
@ -257,7 +269,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
return self.async_abort(reason="too_many_retry")
|
||||
|
||||
if not errors:
|
||||
return await self.async_finish(self.user)
|
||||
return await self.async_finish(self.credential)
|
||||
|
||||
description_placeholders: Dict[str, Optional[str]] = {
|
||||
"mfa_module_name": auth_module.name,
|
||||
|
|
|
@ -8,13 +8,12 @@ from typing import Any, Dict, Optional, cast
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
from .. import AuthManager
|
||||
from ..models import Credentials, User, UserMeta
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
AUTH_PROVIDER_TYPE = "legacy_api_password"
|
||||
CONF_API_PASSWORD = "api_password"
|
||||
|
@ -30,23 +29,6 @@ class InvalidAuthError(HomeAssistantError):
|
|||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
async def async_validate_password(hass: HomeAssistant, password: str) -> Optional[User]:
|
||||
"""Return a user if password is valid. None if not."""
|
||||
auth = cast(AuthManager, hass.auth) # type: ignore
|
||||
providers = auth.get_auth_providers(AUTH_PROVIDER_TYPE)
|
||||
if not providers:
|
||||
raise ValueError("Legacy API password provider not found")
|
||||
|
||||
try:
|
||||
provider = cast(LegacyApiPasswordAuthProvider, providers[0])
|
||||
provider.async_validate_login(password)
|
||||
return await auth.async_get_or_create_user(
|
||||
await provider.async_get_or_create_credentials({})
|
||||
)
|
||||
except InvalidAuthError:
|
||||
return None
|
||||
|
||||
|
||||
@AUTH_PROVIDERS.register(AUTH_PROVIDER_TYPE)
|
||||
class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||
"""An auth provider support legacy api_password."""
|
||||
|
|
|
@ -3,7 +3,14 @@
|
|||
It shows list of users if access from trusted network.
|
||||
Abort login flow if not access from trusted network.
|
||||
"""
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Network,
|
||||
ip_address,
|
||||
ip_network,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -13,7 +20,8 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
from ..models import Credentials, UserMeta
|
||||
from .. import InvalidAuthError
|
||||
from ..models import Credentials, RefreshToken, UserMeta
|
||||
|
||||
IPAddress = Union[IPv4Address, IPv6Address]
|
||||
IPNetwork = Union[IPv4Network, IPv6Network]
|
||||
|
@ -46,10 +54,6 @@ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
|
|||
)
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
"""Raised when try to access from untrusted networks."""
|
||||
|
||||
|
||||
class InvalidUserError(HomeAssistantError):
|
||||
"""Raised when try to login as invalid user."""
|
||||
|
||||
|
@ -163,6 +167,17 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
):
|
||||
raise InvalidAuthError("Not in trusted_networks")
|
||||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||
) -> None:
|
||||
"""Verify a refresh token is still valid."""
|
||||
if remote_ip is None:
|
||||
raise InvalidAuthError(
|
||||
"Unknown remote ip can't be used for trusted network provider."
|
||||
)
|
||||
self.async_validate_access(ip_address(remote_ip))
|
||||
|
||||
|
||||
class TrustedNetworksLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
|
|
@ -115,11 +115,13 @@ Result will be a long-lived access token:
|
|||
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Union
|
||||
import uuid
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth import InvalidAuthError
|
||||
from homeassistant.auth.models import (
|
||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
Credentials,
|
||||
|
@ -180,9 +182,11 @@ RESULT_TYPE_USER = "user"
|
|||
|
||||
|
||||
@bind_hass
|
||||
def create_auth_code(hass, client_id: str, user: User) -> str:
|
||||
def create_auth_code(
|
||||
hass, client_id: str, credential_or_user: Union[Credentials, User]
|
||||
) -> str:
|
||||
"""Create an authorization code to fetch tokens."""
|
||||
return hass.data[DOMAIN](client_id, user)
|
||||
return hass.data[DOMAIN](client_id, credential_or_user)
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
|
@ -228,9 +232,9 @@ class TokenView(HomeAssistantView):
|
|||
requires_auth = False
|
||||
cors_allowed = True
|
||||
|
||||
def __init__(self, retrieve_user):
|
||||
def __init__(self, retrieve_auth):
|
||||
"""Initialize the token view."""
|
||||
self._retrieve_user = retrieve_user
|
||||
self._retrieve_auth = retrieve_auth
|
||||
|
||||
@log_invalid_auth
|
||||
async def post(self, request):
|
||||
|
@ -293,16 +297,15 @@ class TokenView(HomeAssistantView):
|
|||
status_code=HTTP_BAD_REQUEST,
|
||||
)
|
||||
|
||||
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
|
||||
credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code)
|
||||
|
||||
if user is None or not isinstance(user, User):
|
||||
if credential is None or not isinstance(credential, Credentials):
|
||||
return self.json(
|
||||
{"error": "invalid_request", "error_description": "Invalid code"},
|
||||
status_code=HTTP_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# refresh user
|
||||
user = await hass.auth.async_get_user(user.id)
|
||||
user = await hass.auth.async_get_or_create_user(credential)
|
||||
|
||||
if not user.is_active:
|
||||
return self.json(
|
||||
|
@ -310,8 +313,18 @@ class TokenView(HomeAssistantView):
|
|||
status_code=HTTP_FORBIDDEN,
|
||||
)
|
||||
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user, client_id)
|
||||
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
user, client_id, credential=credential
|
||||
)
|
||||
try:
|
||||
access_token = hass.auth.async_create_access_token(
|
||||
refresh_token, remote_addr
|
||||
)
|
||||
except InvalidAuthError as exc:
|
||||
return self.json(
|
||||
{"error": "access_denied", "error_description": str(exc)},
|
||||
status_code=HTTP_FORBIDDEN,
|
||||
)
|
||||
|
||||
return self.json(
|
||||
{
|
||||
|
@ -346,7 +359,15 @@ class TokenView(HomeAssistantView):
|
|||
if refresh_token.client_id != client_id:
|
||||
return self.json({"error": "invalid_request"}, status_code=HTTP_BAD_REQUEST)
|
||||
|
||||
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
|
||||
try:
|
||||
access_token = hass.auth.async_create_access_token(
|
||||
refresh_token, remote_addr
|
||||
)
|
||||
except InvalidAuthError as exc:
|
||||
return self.json(
|
||||
{"error": "access_denied", "error_description": str(exc)},
|
||||
status_code=HTTP_FORBIDDEN,
|
||||
)
|
||||
|
||||
return self.json(
|
||||
{
|
||||
|
@ -482,7 +503,12 @@ async def websocket_create_long_lived_access_token(
|
|||
access_token_expiration=timedelta(days=msg["lifespan"]),
|
||||
)
|
||||
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
try:
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
except InvalidAuthError as exc:
|
||||
return websocket_api.error_message(
|
||||
msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)
|
||||
)
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg["id"], access_token))
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import jwt
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_USER
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
|
@ -62,6 +62,7 @@ def setup_auth(hass, app):
|
|||
return False
|
||||
|
||||
request[KEY_HASS_USER] = refresh_token.user
|
||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||
return True
|
||||
|
||||
async def async_validate_signed_request(request):
|
||||
|
@ -92,6 +93,7 @@ def setup_auth(hass, app):
|
|||
return False
|
||||
|
||||
request[KEY_HASS_USER] = refresh_token.user
|
||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||
return True
|
||||
|
||||
@middleware
|
||||
|
|
|
@ -2,3 +2,4 @@
|
|||
KEY_AUTHENTICATED = "ha_authenticated"
|
||||
KEY_HASS = "hass"
|
||||
KEY_HASS_USER = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"
|
||||
|
|
|
@ -5,6 +5,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.components.auth import indieauth
|
||||
from homeassistant.components.http.const import KEY_HASS_REFRESH_TOKEN_ID
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.const import HTTP_BAD_REQUEST, HTTP_FORBIDDEN
|
||||
|
@ -132,7 +133,9 @@ class UserOnboardingView(_BaseOnboardingView):
|
|||
|
||||
# Return authorization code for fetching tokens and connect
|
||||
# during onboarding.
|
||||
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
|
||||
auth_code = hass.components.auth.create_auth_code(
|
||||
data["client_id"], credentials
|
||||
)
|
||||
return self.json({"auth_code": auth_code})
|
||||
|
||||
|
||||
|
@ -183,7 +186,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
|||
async def post(self, request, data):
|
||||
"""Handle token creation."""
|
||||
hass = request.app["hass"]
|
||||
user = request["hass_user"]
|
||||
refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID]
|
||||
|
||||
async with self._lock:
|
||||
if self._async_is_done():
|
||||
|
@ -201,8 +204,16 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
|||
"invalid client id or redirect uri", HTTP_BAD_REQUEST
|
||||
)
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id)
|
||||
if refresh_token is None or refresh_token.credential is None:
|
||||
return self.json_message(
|
||||
"Credentials for user not available", HTTP_FORBIDDEN
|
||||
)
|
||||
|
||||
# Return authorization code so we can redirect user and log them in
|
||||
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
|
||||
auth_code = hass.components.auth.create_auth_code(
|
||||
data["client_id"], refresh_token.credential
|
||||
)
|
||||
return self.json({"auth_code": auth_code})
|
||||
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ async def test_login(hass):
|
|||
result["flow_id"], {"pin": "123456"}
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["data"].id == "mock-user"
|
||||
assert result["data"].id == "mock-id"
|
||||
|
||||
|
||||
async def test_setup_flow(hass):
|
||||
|
|
|
@ -229,7 +229,7 @@ async def test_login_flow_validates_mfa(hass):
|
|||
result["flow_id"], {"code": MOCK_CODE}
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["data"].id == "mock-user"
|
||||
assert result["data"].id == "mock-id"
|
||||
|
||||
|
||||
async def test_setup_user_notify_service(hass):
|
||||
|
|
|
@ -127,7 +127,7 @@ async def test_login_flow_validates_mfa(hass):
|
|||
result["flow_id"], {"code": MOCK_CODE}
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["data"].id == "mock-user"
|
||||
assert result["data"].id == "mock-id"
|
||||
|
||||
|
||||
async def test_race_condition_in_data_loading(hass):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Test the Trusted Networks auth provider."""
|
||||
from ipaddress import ip_address, ip_network
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
@ -142,6 +143,16 @@ async def test_validate_access(provider):
|
|||
provider.async_validate_access(ip_address("2001:db8::ff00:42:8329"))
|
||||
|
||||
|
||||
async def test_validate_refresh_token(provider):
|
||||
"""Verify re-validation of refresh token."""
|
||||
with patch.object(provider, "async_validate_access") as mock:
|
||||
with pytest.raises(tn_auth.InvalidAuthError):
|
||||
provider.async_validate_refresh_token(Mock(), None)
|
||||
|
||||
provider.async_validate_refresh_token(Mock(), "127.0.0.1")
|
||||
mock.assert_called_once_with(ip_address("127.0.0.1"))
|
||||
|
||||
|
||||
async def test_login_flow(manager, provider):
|
||||
"""Test login flow."""
|
||||
owner = await manager.async_create_user("test-owner")
|
||||
|
|
|
@ -37,6 +37,7 @@ async def test_loading_no_group_data_format(hass, hass_storage):
|
|||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
|
@ -87,12 +88,14 @@ async def test_loading_no_group_data_format(hass, hass_storage):
|
|||
assert len(owner.refresh_tokens) == 1
|
||||
owner_token = list(owner.refresh_tokens.values())[0]
|
||||
assert owner_token.id == "user-token-id"
|
||||
assert owner_token.version == "1.2.3"
|
||||
|
||||
assert system.system_generated is True
|
||||
assert system.groups == []
|
||||
assert len(system.refresh_tokens) == 1
|
||||
system_token = list(system.refresh_tokens.values())[0]
|
||||
assert system_token.id == "system-token-id"
|
||||
assert system_token.version is None
|
||||
|
||||
|
||||
async def test_loading_all_access_group_data_format(hass, hass_storage):
|
||||
|
@ -129,6 +132,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
|
@ -139,6 +143,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||
"token": "some-token",
|
||||
"user_id": "system-id",
|
||||
"version": None,
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
|
@ -179,12 +184,14 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||
assert len(owner.refresh_tokens) == 1
|
||||
owner_token = list(owner.refresh_tokens.values())[0]
|
||||
assert owner_token.id == "user-token-id"
|
||||
assert owner_token.version == "1.2.3"
|
||||
|
||||
assert system.system_generated is True
|
||||
assert system.groups == []
|
||||
assert len(system.refresh_tokens) == 1
|
||||
system_token = list(system.refresh_tokens.values())[0]
|
||||
assert system_token.id == "system-token-id"
|
||||
assert system_token.version is None
|
||||
|
||||
|
||||
async def test_loading_empty_data(hass, hass_storage):
|
||||
|
|
|
@ -7,7 +7,12 @@ import pytest
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant.auth import auth_store, const as auth_const, models as auth_models
|
||||
from homeassistant.auth import (
|
||||
InvalidAuthError,
|
||||
auth_store,
|
||||
const as auth_const,
|
||||
models as auth_models,
|
||||
)
|
||||
from homeassistant.auth.const import MFA_SESSION_EXPIRATION
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -162,7 +167,10 @@ async def test_create_new_user(hass):
|
|||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||
)
|
||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
user = step["result"]
|
||||
credential = step["result"]
|
||||
assert credential is not None
|
||||
|
||||
user = await manager.async_get_or_create_user(credential)
|
||||
assert user is not None
|
||||
assert user.is_owner is False
|
||||
assert user.name == "Test Name"
|
||||
|
@ -229,7 +237,8 @@ async def test_login_as_existing_user(mock_hass):
|
|||
)
|
||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
user = step["result"]
|
||||
credential = step["result"]
|
||||
user = await manager.async_get_user_by_credentials(credential)
|
||||
assert user is not None
|
||||
assert user.id == "mock-user"
|
||||
assert user.is_owner is False
|
||||
|
@ -259,7 +268,8 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
|
|||
step = await manager.login_flow.async_configure(
|
||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||
)
|
||||
user = step["result"]
|
||||
credential = step["result"]
|
||||
user = await manager.async_get_or_create_user(credential)
|
||||
assert user is not None
|
||||
|
||||
step = await manager.login_flow.async_init(
|
||||
|
@ -293,13 +303,19 @@ async def test_saving_loading(hass, hass_storage):
|
|||
step = await manager.login_flow.async_configure(
|
||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||
)
|
||||
user = step["result"]
|
||||
credential = step["result"]
|
||||
user = await manager.async_get_or_create_user(credential)
|
||||
|
||||
await manager.async_activate_user(user)
|
||||
# the first refresh token will be used to create access token
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user, CLIENT_ID, credential=credential
|
||||
)
|
||||
manager.async_create_access_token(refresh_token, "192.168.0.1")
|
||||
# the second refresh token will not be used
|
||||
await manager.async_create_refresh_token(user, "dummy-client")
|
||||
await manager.async_create_refresh_token(
|
||||
user, "dummy-client", credential=credential
|
||||
)
|
||||
|
||||
await flush_store(manager._store._store)
|
||||
|
||||
|
@ -452,6 +468,46 @@ async def test_refresh_token_type_long_lived_access_token(hass):
|
|||
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
|
||||
|
||||
async def test_refresh_token_provider_validation(mock_hass):
|
||||
"""Test that creating access token from refresh token checks with provider."""
|
||||
manager = await auth.auth_manager_from_config(
|
||||
mock_hass,
|
||||
[
|
||||
{
|
||||
"type": "insecure_example",
|
||||
"users": [{"username": "test-user", "password": "test-pass"}],
|
||||
}
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
credential = auth_models.Credentials(
|
||||
id="mock-credential-id",
|
||||
auth_provider_type="insecure_example",
|
||||
auth_provider_id=None,
|
||||
data={"username": "test-user"},
|
||||
is_new=False,
|
||||
)
|
||||
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
user.credentials.append(credential)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user, CLIENT_ID, credential=credential
|
||||
)
|
||||
ip = "127.0.0.1"
|
||||
|
||||
assert manager.async_create_access_token(refresh_token, ip) is not None
|
||||
|
||||
with patch(
|
||||
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
|
||||
side_effect=InvalidAuthError("Invalid access"),
|
||||
) as call:
|
||||
with pytest.raises(InvalidAuthError):
|
||||
manager.async_create_access_token(refresh_token, ip)
|
||||
|
||||
call.assert_called_with(refresh_token, ip)
|
||||
|
||||
|
||||
async def test_cannot_deactive_owner(mock_hass):
|
||||
"""Test that we cannot deactivate the owner."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
|
@ -626,14 +682,10 @@ async def test_login_with_auth_module(mock_hass):
|
|||
step["flow_id"], {"pin": "test-pin"}
|
||||
)
|
||||
|
||||
# Finally passed, get user
|
||||
# Finally passed, get credential
|
||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
user = step["result"]
|
||||
assert user is not None
|
||||
assert user.id == "mock-user"
|
||||
assert user.is_owner is False
|
||||
assert user.is_active is False
|
||||
assert user.name == "Paulus"
|
||||
assert step["result"]
|
||||
assert step["result"].id == "mock-id"
|
||||
|
||||
|
||||
async def test_login_with_multi_auth_module(mock_hass):
|
||||
|
@ -703,14 +755,10 @@ async def test_login_with_multi_auth_module(mock_hass):
|
|||
step["flow_id"], {"pin": "test-pin2"}
|
||||
)
|
||||
|
||||
# Finally passed, get user
|
||||
# Finally passed, get credential
|
||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
user = step["result"]
|
||||
assert user is not None
|
||||
assert user.id == "mock-user"
|
||||
assert user.is_owner is False
|
||||
assert user.is_active is False
|
||||
assert user.name == "Paulus"
|
||||
assert step["result"]
|
||||
assert step["result"].id == "mock-id"
|
||||
|
||||
|
||||
async def test_auth_module_expired_session(mock_hass):
|
||||
|
@ -792,7 +840,8 @@ async def test_enable_mfa_for_user(hass, hass_storage):
|
|||
step = await manager.login_flow.async_configure(
|
||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||
)
|
||||
user = step["result"]
|
||||
credential = step["result"]
|
||||
user = await manager.async_get_or_create_user(credential)
|
||||
assert user is not None
|
||||
|
||||
# new user don't have mfa enabled
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.auth import InvalidAuthError
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.components import auth
|
||||
from homeassistant.components.auth import RESULT_TYPE_USER
|
||||
|
@ -13,6 +14,24 @@ from . import async_setup_auth
|
|||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
||||
|
||||
|
||||
async def async_setup_user_refresh_token(hass):
|
||||
"""Create a testing user with a connected credential."""
|
||||
user = await hass.auth.async_create_user("Test User")
|
||||
|
||||
credential = Credentials(
|
||||
id="mock-credential-id",
|
||||
auth_provider_type="insecure_example",
|
||||
auth_provider_id=None,
|
||||
data={"username": "test-user"},
|
||||
is_new=False,
|
||||
)
|
||||
user.credentials.append(credential)
|
||||
|
||||
return await hass.auth.async_create_refresh_token(
|
||||
user, CLIENT_ID, credential=credential
|
||||
)
|
||||
|
||||
|
||||
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||
"""Test logging in with new user and refreshing tokens."""
|
||||
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||
|
@ -107,12 +126,6 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
|||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
user = refresh_token.user
|
||||
credential = Credentials(
|
||||
auth_provider_type="homeassistant", auth_provider_id=None, data={}, id="test-id"
|
||||
)
|
||||
user.credentials.append(credential)
|
||||
assert len(user.credentials) == 1
|
||||
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
|
||||
|
@ -185,8 +198,7 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
|
|||
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||
"""Test that we verify client ID."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
user = await hass.auth.async_create_user("Test User")
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
|
||||
refresh_token = await async_setup_user_refresh_token(hass)
|
||||
|
||||
# No client ID
|
||||
resp = await client.post(
|
||||
|
@ -229,11 +241,37 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
|||
)
|
||||
|
||||
|
||||
async def test_refresh_token_provider_rejected(
|
||||
hass, aiohttp_client, hass_admin_user, hass_admin_credential
|
||||
):
|
||||
"""Test that we verify client ID."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
refresh_token = await async_setup_user_refresh_token(hass)
|
||||
|
||||
# Rejected by provider
|
||||
with patch(
|
||||
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
|
||||
side_effect=InvalidAuthError("Invalid access"),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/auth/token",
|
||||
data={
|
||||
"client_id": CLIENT_ID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token.token,
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 403
|
||||
result = await resp.json()
|
||||
assert result["error"] == "access_denied"
|
||||
assert result["error_description"] == "Invalid access"
|
||||
|
||||
|
||||
async def test_revoking_refresh_token(hass, aiohttp_client):
|
||||
"""Test that we can revoke refresh tokens."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
user = await hass.auth.async_create_user("Test User")
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
|
||||
refresh_token = await async_setup_user_refresh_token(hass)
|
||||
|
||||
# Test that we can create an access token
|
||||
resp = await client.post(
|
||||
|
|
|
@ -48,7 +48,9 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
|
|||
id="hij", name="Inactive User", is_active=False, groups=[group]
|
||||
).add_to_hass(hass)
|
||||
|
||||
refresh_token = await hass.auth.async_create_refresh_token(owner, CLIENT_ID)
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
owner, CLIENT_ID, credential=owner.credentials[0]
|
||||
)
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
client = await hass_ws_client(hass, access_token)
|
||||
|
@ -60,13 +62,13 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
|
|||
assert len(data) == 4
|
||||
assert data[0] == {
|
||||
"id": hass_admin_user.id,
|
||||
"username": None,
|
||||
"username": "admin",
|
||||
"name": "Mock User",
|
||||
"is_owner": False,
|
||||
"is_active": True,
|
||||
"system_generated": False,
|
||||
"group_ids": [group.id for group in hass_admin_user.groups],
|
||||
"credentials": [],
|
||||
"credentials": [{"type": "homeassistant"}],
|
||||
}
|
||||
assert data[1] == {
|
||||
"id": owner.id,
|
||||
|
|
|
@ -4,24 +4,19 @@ import pytest
|
|||
from homeassistant.auth.providers import homeassistant as prov_ha
|
||||
from homeassistant.components.config import auth_provider_homeassistant as auth_ha
|
||||
|
||||
from tests.common import CLIENT_ID, MockUser, register_auth_provider
|
||||
from tests.common import CLIENT_ID, MockUser
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_config(hass):
|
||||
"""Fixture that sets up the auth provider homeassistant module."""
|
||||
hass.loop.run_until_complete(
|
||||
register_auth_provider(hass, {"type": "homeassistant"})
|
||||
)
|
||||
hass.loop.run_until_complete(auth_ha.async_setup(hass))
|
||||
async def setup_config(hass, local_auth):
|
||||
"""Fixture that sets up the auth provider ."""
|
||||
await auth_ha.async_setup(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_provider(hass):
|
||||
async def auth_provider(local_auth):
|
||||
"""Hass auth provider."""
|
||||
provider = hass.auth.auth_providers[0]
|
||||
await provider.async_initialize()
|
||||
return provider
|
||||
return local_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -34,8 +29,8 @@ async def owner_access_token(hass, hass_owner_user):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user_credential(hass, auth_provider):
|
||||
"""Add a test user."""
|
||||
async def hass_admin_credential(hass, auth_provider):
|
||||
"""Overload credentials to admin user."""
|
||||
await hass.async_add_executor_job(
|
||||
auth_provider.data.add_auth, "test-user", "test-pass"
|
||||
)
|
||||
|
@ -124,7 +119,7 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
|
|||
"id": 5,
|
||||
"type": "config/auth_provider/homeassistant/create",
|
||||
"user_id": user.id,
|
||||
"username": "test-user",
|
||||
"username": "test-user2",
|
||||
"password": "test-pass",
|
||||
}
|
||||
)
|
||||
|
@ -135,10 +130,10 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
|
|||
creds = user.credentials[0]
|
||||
assert creds.auth_provider_type == "homeassistant"
|
||||
assert creds.auth_provider_id is None
|
||||
assert creds.data == {"username": "test-user"}
|
||||
assert creds.data == {"username": "test-user2"}
|
||||
assert prov_ha.STORAGE_KEY in hass_storage
|
||||
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][0]
|
||||
assert entry["username"] == "test-user"
|
||||
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][1]
|
||||
assert entry["username"] == "test-user2"
|
||||
|
||||
|
||||
async def test_create_auth_duplicate_username(hass, hass_ws_client, hass_storage):
|
||||
|
@ -242,7 +237,7 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
|
|||
{
|
||||
"id": 5,
|
||||
"type": "config/auth_provider/homeassistant/delete",
|
||||
"username": "test-user",
|
||||
"username": "test-user2",
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -251,12 +246,8 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
|
|||
assert result["error"]["code"] == "auth_not_found"
|
||||
|
||||
|
||||
async def test_change_password(
|
||||
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
|
||||
):
|
||||
async def test_change_password(hass, hass_ws_client, auth_provider):
|
||||
"""Test that change password succeeds with valid password."""
|
||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
await client.send_json(
|
||||
{
|
||||
|
@ -273,10 +264,9 @@ async def test_change_password(
|
|||
|
||||
|
||||
async def test_change_password_wrong_pw(
|
||||
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
|
||||
hass, hass_ws_client, hass_admin_user, auth_provider
|
||||
):
|
||||
"""Test that change password fails with invalid password."""
|
||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
await client.send_json(
|
||||
|
@ -295,8 +285,9 @@ async def test_change_password_wrong_pw(
|
|||
await auth_provider.async_validate_login("test-user", "new-pass")
|
||||
|
||||
|
||||
async def test_change_password_no_creds(hass, hass_ws_client):
|
||||
async def test_change_password_no_creds(hass, hass_ws_client, hass_admin_user):
|
||||
"""Test that change password fails with no credentials."""
|
||||
hass_admin_user.credentials.clear()
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json(
|
||||
|
@ -313,9 +304,7 @@ async def test_change_password_no_creds(hass, hass_ws_client):
|
|||
assert result["error"]["code"] == "credentials_not_found"
|
||||
|
||||
|
||||
async def test_admin_change_password_not_owner(
|
||||
hass, hass_ws_client, auth_provider, test_user_credential
|
||||
):
|
||||
async def test_admin_change_password_not_owner(hass, hass_ws_client, auth_provider):
|
||||
"""Test that change password fails when not owner."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
|
@ -358,6 +347,8 @@ async def test_admin_change_password_no_cred(
|
|||
hass, hass_ws_client, owner_access_token, hass_admin_user
|
||||
):
|
||||
"""Test that change password fails with unknown credential."""
|
||||
|
||||
hass_admin_user.credentials.clear()
|
||||
client = await hass_ws_client(hass, owner_access_token)
|
||||
|
||||
await client.send_json(
|
||||
|
@ -379,12 +370,9 @@ async def test_admin_change_password(
|
|||
hass_ws_client,
|
||||
owner_access_token,
|
||||
auth_provider,
|
||||
test_user_credential,
|
||||
hass_admin_user,
|
||||
):
|
||||
"""Test that owners can change any password."""
|
||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
||||
|
||||
client = await hass_ws_client(hass, owner_access_token)
|
||||
|
||||
await client.send_json(
|
||||
|
|
|
@ -247,7 +247,7 @@ async def test_onboarding_user_race(hass, hass_storage, aiohttp_client):
|
|||
assert sorted([res1.status, res2.status]) == [200, HTTP_FORBIDDEN]
|
||||
|
||||
|
||||
async def test_onboarding_integration(hass, hass_storage, hass_client):
|
||||
async def test_onboarding_integration(hass, hass_storage, hass_client, hass_admin_user):
|
||||
"""Test finishing integration step."""
|
||||
mock_storage(hass_storage, {"done": [const.STEP_USER]})
|
||||
|
||||
|
@ -288,6 +288,28 @@ async def test_onboarding_integration(hass, hass_storage, hass_client):
|
|||
assert len(user.refresh_tokens) == 2, user
|
||||
|
||||
|
||||
async def test_onboarding_integration_missing_credential(
|
||||
hass, hass_storage, hass_client, hass_access_token
|
||||
):
|
||||
"""Test that we fail integration step if user is missing credentials."""
|
||||
mock_storage(hass_storage, {"done": [const.STEP_USER]})
|
||||
|
||||
assert await async_setup_component(hass, "onboarding", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token.credential = None
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
resp = await client.post(
|
||||
"/api/onboarding/integration",
|
||||
json={"client_id": CLIENT_ID, "redirect_uri": CLIENT_REDIRECT_URI},
|
||||
)
|
||||
|
||||
assert resp.status == 403
|
||||
|
||||
|
||||
async def test_onboarding_integration_invalid_redirect_uri(
|
||||
hass, hass_storage, hass_client
|
||||
):
|
||||
|
|
|
@ -14,6 +14,7 @@ import requests_mock as _requests_mock
|
|||
|
||||
from homeassistant import core as ha, loader, runner, util
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.auth.providers import homeassistant, legacy_api_password
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
|
@ -201,10 +202,20 @@ def mock_device_tracker_conf():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_access_token(hass, hass_admin_user):
|
||||
async def hass_admin_credential(hass, local_auth):
|
||||
"""Provide credentials for admin user."""
|
||||
await hass.async_add_executor_job(local_auth.data.add_auth, "admin", "admin-pass")
|
||||
|
||||
return await local_auth.async_get_or_create_credentials({"username": "admin"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def hass_access_token(hass, hass_admin_user, hass_admin_credential):
|
||||
"""Return an access token to access Home Assistant."""
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)
|
||||
await hass.auth.async_link_user(hass_admin_user, hass_admin_credential)
|
||||
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
|
||||
)
|
||||
return hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
|
@ -234,10 +245,21 @@ def hass_read_only_user(hass, local_auth):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_read_only_access_token(hass, hass_read_only_user):
|
||||
def hass_read_only_access_token(hass, hass_read_only_user, local_auth):
|
||||
"""Return a Home Assistant read only user."""
|
||||
credential = Credentials(
|
||||
id="mock-readonly-credential-id",
|
||||
auth_provider_type="homeassistant",
|
||||
auth_provider_id=None,
|
||||
data={"username": "readonly"},
|
||||
is_new=False,
|
||||
)
|
||||
hass_read_only_user.credentials.append(credential)
|
||||
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(hass_read_only_user, CLIENT_ID)
|
||||
hass.auth.async_create_refresh_token(
|
||||
hass_read_only_user, CLIENT_ID, credential=credential
|
||||
)
|
||||
)
|
||||
return hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
|
@ -260,6 +282,7 @@ def local_auth(hass):
|
|||
prv = homeassistant.HassAuthProvider(
|
||||
hass, hass.auth._store, {"type": "homeassistant"}
|
||||
)
|
||||
hass.loop.run_until_complete(prv.async_initialize())
|
||||
hass.auth._providers[(prv.type, prv.id)] = prv
|
||||
return prv
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue