Refactoring login flow (#16104)
* Abstract LoginFlow * Lint and typings
This commit is contained in:
parent
cdb8361050
commit
1ce51bfbd6
8 changed files with 89 additions and 86 deletions
|
@ -259,7 +259,7 @@ class AuthManager:
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
auth_provider = self._providers[handler]
|
auth_provider = self._providers[handler]
|
||||||
|
|
||||||
return await auth_provider.async_credential_flow(context)
|
return await auth_provider.async_login_flow(context)
|
||||||
|
|
||||||
async def _async_finish_login_flow(
|
async def _async_finish_login_flow(
|
||||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
|
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
|
||||||
|
|
|
@ -10,10 +10,11 @@ from voluptuous.humanize import humanize_error
|
||||||
from homeassistant import data_entry_flow, requirements
|
from homeassistant import data_entry_flow, requirements
|
||||||
from homeassistant.core import callback, HomeAssistant
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
from homeassistant.auth.auth_store import AuthStore
|
from ..auth_store import AuthStore
|
||||||
from homeassistant.auth.models import Credentials, UserMeta
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DATA_REQS = 'auth_prov_reqs_processed'
|
DATA_REQS = 'auth_prov_reqs_processed'
|
||||||
|
@ -80,9 +81,11 @@ class AuthProvider:
|
||||||
|
|
||||||
# Implement by extending class
|
# Implement by extending class
|
||||||
|
|
||||||
async def async_credential_flow(
|
async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow':
|
||||||
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
|
"""Return the data flow for logging in with auth provider.
|
||||||
"""Return the data flow for logging in with auth provider."""
|
|
||||||
|
Auth provider should extend LoginFlow and return an instance.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
|
@ -149,3 +152,30 @@ async def load_auth_provider_module(
|
||||||
|
|
||||||
processed.add(provider)
|
processed.add(provider)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
|
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||||
|
"""Initialize the login flow."""
|
||||||
|
self._auth_provider = auth_provider
|
||||||
|
self.created_at = dt_util.utcnow()
|
||||||
|
self.user = None
|
||||||
|
|
||||||
|
async def async_step_init(
|
||||||
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
|
-> Dict[str, Any]:
|
||||||
|
"""Handle the first step of login flow.
|
||||||
|
|
||||||
|
Return self.async_show_form(step_id='init') if user_input == None.
|
||||||
|
Return await self.async_finish(flow_result) if login init step pass.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_finish(self, flow_result: Any) -> Dict:
|
||||||
|
"""Handle the pass of login flow."""
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self._auth_provider.name,
|
||||||
|
data=flow_result
|
||||||
|
)
|
||||||
|
|
|
@ -3,19 +3,18 @@ import base64
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
|
||||||
from homeassistant.const import CONF_ID
|
from homeassistant.const import CONF_ID
|
||||||
from homeassistant.core import callback, HomeAssistant
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from homeassistant.auth.util import generate_secret
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
|
||||||
from ..models import Credentials, UserMeta
|
from ..models import Credentials, UserMeta
|
||||||
|
from ..util import generate_secret
|
||||||
|
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||||
|
@ -159,10 +158,10 @@ class HassAuthProvider(AuthProvider):
|
||||||
self.data = Data(self.hass)
|
self.data = Data(self.hass)
|
||||||
await self.data.async_load()
|
await self.data.async_load()
|
||||||
|
|
||||||
async def async_credential_flow(
|
async def async_login_flow(
|
||||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
self, context: Optional[Dict]) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return HassLoginFlow(self)
|
||||||
|
|
||||||
async def async_validate_login(self, username: str, password: str) -> None:
|
async def async_validate_login(self, username: str, password: str) -> None:
|
||||||
"""Helper to validate a username and password."""
|
"""Helper to validate a username and password."""
|
||||||
|
@ -207,13 +206,9 @@ class HassAuthProvider(AuthProvider):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class HassLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider: HassAuthProvider) -> None:
|
|
||||||
"""Initialize the login flow."""
|
|
||||||
self._auth_provider = auth_provider
|
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None) \
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
-> Dict[str, Any]:
|
-> Dict[str, Any]:
|
||||||
|
@ -222,16 +217,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
await self._auth_provider.async_validate_login(
|
await cast(HassAuthProvider, self._auth_provider)\
|
||||||
user_input['username'], user_input['password'])
|
.async_validate_login(user_input['username'],
|
||||||
|
user_input['password'])
|
||||||
except InvalidAuth:
|
except InvalidAuth:
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return self.async_create_entry(
|
user_input.pop('password')
|
||||||
title=self._auth_provider.name,
|
return await self.async_finish(user_input)
|
||||||
data=user_input
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = OrderedDict() # type: Dict[str, type]
|
schema = OrderedDict() # type: Dict[str, type]
|
||||||
schema['username'] = str
|
schema['username'] = str
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
"""Example auth provider."""
|
"""Example auth provider."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant import data_entry_flow
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||||
from ..models import Credentials, UserMeta
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,10 +32,9 @@ class InvalidAuthError(HomeAssistantError):
|
||||||
class ExampleAuthProvider(AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_credential_flow(
|
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return ExampleLoginFlow(self)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_login(self, username: str, password: str) -> None:
|
def async_validate_login(self, username: str, password: str) -> None:
|
||||||
|
@ -90,13 +88,9 @@ class ExampleAuthProvider(AuthProvider):
|
||||||
return UserMeta(name=name, is_active=True)
|
return UserMeta(name=name, is_active=True)
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class ExampleLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
|
|
||||||
"""Initialize the login flow."""
|
|
||||||
self._auth_provider = auth_provider
|
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None) \
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
-> Dict[str, Any]:
|
-> Dict[str, Any]:
|
||||||
|
@ -105,16 +99,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
self._auth_provider.async_validate_login(
|
cast(ExampleAuthProvider, self._auth_provider)\
|
||||||
user_input['username'], user_input['password'])
|
.async_validate_login(user_input['username'],
|
||||||
|
user_input['password'])
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return self.async_create_entry(
|
user_input.pop('password')
|
||||||
title=self._auth_provider.name,
|
return await self.async_finish(user_input)
|
||||||
data=user_input
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = OrderedDict() # type: Dict[str, type]
|
schema = OrderedDict() # type: Dict[str, type]
|
||||||
schema['username'] = str
|
schema['username'] = str
|
||||||
|
|
|
@ -3,18 +3,16 @@ Support Legacy API password auth provider.
|
||||||
|
|
||||||
It will be removed when auth system production ready
|
It will be removed when auth system production ready
|
||||||
"""
|
"""
|
||||||
from collections import OrderedDict
|
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant import data_entry_flow
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||||
from ..models import Credentials, UserMeta
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,10 +37,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Legacy API Password'
|
DEFAULT_TITLE = 'Legacy API Password'
|
||||||
|
|
||||||
async def async_credential_flow(
|
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return LegacyLoginFlow(self)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_login(self, password: str) -> None:
|
def async_validate_login(self, password: str) -> None:
|
||||||
|
@ -81,13 +78,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
return UserMeta(name=LEGACY_USER, is_active=True)
|
return UserMeta(name=LEGACY_USER, is_active=True)
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LegacyLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
|
|
||||||
"""Initialize the login flow."""
|
|
||||||
self._auth_provider = auth_provider
|
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None) \
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
-> Dict[str, Any]:
|
-> Dict[str, Any]:
|
||||||
|
@ -96,22 +89,16 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
self._auth_provider.async_validate_login(
|
cast(LegacyApiPasswordAuthProvider, self._auth_provider)\
|
||||||
user_input['password'])
|
.async_validate_login(user_input['password'])
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return self.async_create_entry(
|
return await self.async_finish({})
|
||||||
title=self._auth_provider.name,
|
|
||||||
data={}
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = OrderedDict() # type: Dict[str, type]
|
|
||||||
schema['password'] = str
|
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id='init',
|
step_id='init',
|
||||||
data_schema=vol.Schema(schema),
|
data_schema=vol.Schema({'password': str}),
|
||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,11 +7,11 @@ from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
|
||||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
|
||||||
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||||
from ..models import Credentials, UserMeta
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
|
@ -35,8 +35,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Trusted Networks'
|
DEFAULT_TITLE = 'Trusted Networks'
|
||||||
|
|
||||||
async def async_credential_flow(
|
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
assert context is not None
|
assert context is not None
|
||||||
users = await self.store.async_get_users()
|
users = await self.store.async_get_users()
|
||||||
|
@ -44,8 +43,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
for user in users
|
for user in users
|
||||||
if not user.system_generated and user.is_active}
|
if not user.system_generated and user.is_active}
|
||||||
|
|
||||||
return LoginFlow(self, cast(str, context.get('ip_address')),
|
return TrustedNetworksLoginFlow(
|
||||||
available_users)
|
self, cast(str, context.get('ip_address')), available_users)
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]) -> Credentials:
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
|
@ -92,14 +91,14 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
raise InvalidAuthError('Not in trusted_networks')
|
raise InvalidAuthError('Not in trusted_networks')
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class TrustedNetworksLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
||||||
ip_address: str, available_users: Dict[str, Optional[str]]) \
|
ip_address: str, available_users: Dict[str, Optional[str]]) \
|
||||||
-> None:
|
-> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
super().__init__(auth_provider)
|
||||||
self._available_users = available_users
|
self._available_users = available_users
|
||||||
self._ip_address = ip_address
|
self._ip_address = ip_address
|
||||||
|
|
||||||
|
@ -109,7 +108,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
try:
|
try:
|
||||||
self._auth_provider.async_validate_access(self._ip_address)
|
cast(TrustedNetworksAuthProvider, self._auth_provider)\
|
||||||
|
.async_validate_access(self._ip_address)
|
||||||
|
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
@ -125,10 +125,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
errors['base'] = 'invalid_auth'
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return self.async_create_entry(
|
return await self.async_finish(user_input)
|
||||||
title=self._auth_provider.name,
|
|
||||||
data=user_input
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = {'user': vol.In(self._available_users)}
|
schema = {'user': vol.In(self._available_users)}
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from unittest.mock import Mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.auth import auth_manager_from_config
|
from homeassistant.auth import auth_manager_from_config, auth_store
|
||||||
from homeassistant.auth.providers import (
|
from homeassistant.auth.providers import (
|
||||||
auth_provider_from_config, homeassistant as hass_auth)
|
auth_provider_from_config, homeassistant as hass_auth)
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ async def test_adding_user(data, hass):
|
||||||
|
|
||||||
|
|
||||||
async def test_adding_user_duplicate_username(data, hass):
|
async def test_adding_user_duplicate_username(data, hass):
|
||||||
"""Test adding a user."""
|
"""Test adding a user with duplicate username."""
|
||||||
data.add_auth('test-user', 'test-pass')
|
data.add_auth('test-user', 'test-pass')
|
||||||
with pytest.raises(hass_auth.InvalidUser):
|
with pytest.raises(hass_auth.InvalidUser):
|
||||||
data.add_auth('test-user', 'other-pass')
|
data.add_auth('test-user', 'other-pass')
|
||||||
|
@ -37,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
|
||||||
|
|
||||||
|
|
||||||
async def test_validating_password_invalid_password(data, hass):
|
async def test_validating_password_invalid_password(data, hass):
|
||||||
"""Test validating an invalid user."""
|
"""Test validating an invalid password."""
|
||||||
data.add_auth('test-user', 'test-pass')
|
data.add_auth('test-user', 'test-pass')
|
||||||
|
|
||||||
with pytest.raises(hass_auth.InvalidAuth):
|
with pytest.raises(hass_auth.InvalidAuth):
|
||||||
|
@ -67,8 +67,9 @@ async def test_login_flow_validates(data, hass):
|
||||||
data.add_auth('test-user', 'test-pass')
|
data.add_auth('test-user', 'test-pass')
|
||||||
await data.async_save()
|
await data.async_save()
|
||||||
|
|
||||||
provider = hass_auth.HassAuthProvider(hass, None, {})
|
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
|
||||||
flow = hass_auth.LoginFlow(provider)
|
{'type': 'homeassistant'})
|
||||||
|
flow = await provider.async_login_flow({})
|
||||||
result = await flow.async_step_init()
|
result = await flow.async_step_init()
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
|
||||||
|
@ -91,6 +92,7 @@ async def test_login_flow_validates(data, hass):
|
||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
})
|
})
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result['data']['username'] == 'test-user'
|
||||||
|
|
||||||
|
|
||||||
async def test_saving_loading(data, hass):
|
async def test_saving_loading(data, hass):
|
||||||
|
|
|
@ -72,7 +72,7 @@ async def test_login_flow(manager, provider):
|
||||||
user = await manager.async_create_user("test-user")
|
user = await manager.async_create_user("test-user")
|
||||||
|
|
||||||
# trusted network didn't loaded
|
# trusted network didn't loaded
|
||||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||||
step = await flow.async_step_init()
|
step = await flow.async_step_init()
|
||||||
assert step['step_id'] == 'init'
|
assert step['step_id'] == 'init'
|
||||||
assert step['errors']['base'] == 'invalid_auth'
|
assert step['errors']['base'] == 'invalid_auth'
|
||||||
|
@ -80,13 +80,13 @@ async def test_login_flow(manager, provider):
|
||||||
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||||
|
|
||||||
# not from trusted network
|
# not from trusted network
|
||||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||||
step = await flow.async_step_init()
|
step = await flow.async_step_init()
|
||||||
assert step['step_id'] == 'init'
|
assert step['step_id'] == 'init'
|
||||||
assert step['errors']['base'] == 'invalid_auth'
|
assert step['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
# from trusted network, list users
|
# from trusted network, list users
|
||||||
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
|
||||||
step = await flow.async_step_init()
|
step = await flow.async_step_init()
|
||||||
assert step['step_id'] == 'init'
|
assert step['step_id'] == 'init'
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue