Refactoring login flow (#16104)
* Abstract LoginFlow * Lint and typings
This commit is contained in:
parent
cdb8361050
commit
1ce51bfbd6
8 changed files with 89 additions and 86 deletions
|
@ -259,7 +259,7 @@ class AuthManager:
|
|||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
return await auth_provider.async_credential_flow(context)
|
||||
return await auth_provider.async_login_flow(context)
|
||||
|
||||
async def _async_finish_login_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
|
||||
|
|
|
@ -10,10 +10,11 @@ from voluptuous.humanize import humanize_error
|
|||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from homeassistant.auth.auth_store import AuthStore
|
||||
from homeassistant.auth.models import Credentials, UserMeta
|
||||
from ..auth_store import AuthStore
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = 'auth_prov_reqs_processed'
|
||||
|
@ -80,9 +81,11 @@ class AuthProvider:
|
|||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
"""Return the data flow for logging in with auth provider.
|
||||
|
||||
Auth provider should extend LoginFlow and return an instance.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
|
@ -149,3 +152,30 @@ async def load_auth_provider_module(
|
|||
|
||||
processed.add(provider)
|
||||
return module
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self.created_at = dt_util.utcnow()
|
||||
self.user = None
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
-> Dict[str, Any]:
|
||||
"""Handle the first step of login flow.
|
||||
|
||||
Return self.async_show_form(step_id='init') if user_input == None.
|
||||
Return await self.async_finish(flow_result) if login init step pass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_finish(self, flow_result: Any) -> Dict:
|
||||
"""Handle the pass of login flow."""
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data=flow_result
|
||||
)
|
||||
|
|
|
@ -3,19 +3,18 @@ import base64
|
|||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||
from ..models import Credentials, UserMeta
|
||||
from ..util import generate_secret
|
||||
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||
|
@ -159,10 +158,10 @@ class HassAuthProvider(AuthProvider):
|
|||
self.data = Data(self.hass)
|
||||
await self.data.async_load()
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
async def async_login_flow(
|
||||
self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
return HassLoginFlow(self)
|
||||
|
||||
async def async_validate_login(self, username: str, password: str) -> None:
|
||||
"""Helper to validate a username and password."""
|
||||
|
@ -207,13 +206,9 @@ class HassAuthProvider(AuthProvider):
|
|||
pass
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
class HassLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider: HassAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
-> Dict[str, Any]:
|
||||
|
@ -222,16 +217,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
if user_input is not None:
|
||||
try:
|
||||
await self._auth_provider.async_validate_login(
|
||||
user_input['username'], user_input['password'])
|
||||
await cast(HassAuthProvider, self._auth_provider)\
|
||||
.async_validate_login(user_input['username'],
|
||||
user_input['password'])
|
||||
except InvalidAuth:
|
||||
errors['base'] = 'invalid_auth'
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data=user_input
|
||||
)
|
||||
user_input.pop('password')
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema = OrderedDict() # type: Dict[str, type]
|
||||
schema['username'] = str
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
"""Example auth provider."""
|
||||
from collections import OrderedDict
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
|
||||
|
@ -33,10 +32,9 @@ class InvalidAuthError(HomeAssistantError):
|
|||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
return ExampleLoginFlow(self)
|
||||
|
||||
@callback
|
||||
def async_validate_login(self, username: str, password: str) -> None:
|
||||
|
@ -90,13 +88,9 @@ class ExampleAuthProvider(AuthProvider):
|
|||
return UserMeta(name=name, is_active=True)
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
class ExampleLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
-> Dict[str, Any]:
|
||||
|
@ -105,16 +99,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
if user_input is not None:
|
||||
try:
|
||||
self._auth_provider.async_validate_login(
|
||||
user_input['username'], user_input['password'])
|
||||
cast(ExampleAuthProvider, self._auth_provider)\
|
||||
.async_validate_login(user_input['username'],
|
||||
user_input['password'])
|
||||
except InvalidAuthError:
|
||||
errors['base'] = 'invalid_auth'
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data=user_input
|
||||
)
|
||||
user_input.pop('password')
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema = OrderedDict() # type: Dict[str, type]
|
||||
schema['username'] = str
|
||||
|
|
|
@ -3,18 +3,16 @@ Support Legacy API password auth provider.
|
|||
|
||||
It will be removed when auth system production ready
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
|
||||
|
@ -39,10 +37,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
|
||||
DEFAULT_TITLE = 'Legacy API Password'
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
return LegacyLoginFlow(self)
|
||||
|
||||
@callback
|
||||
def async_validate_login(self, password: str) -> None:
|
||||
|
@ -81,13 +78,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
return UserMeta(name=LEGACY_USER, is_active=True)
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
class LegacyLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
-> Dict[str, Any]:
|
||||
|
@ -96,22 +89,16 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
if user_input is not None:
|
||||
try:
|
||||
self._auth_provider.async_validate_login(
|
||||
user_input['password'])
|
||||
cast(LegacyApiPasswordAuthProvider, self._auth_provider)\
|
||||
.async_validate_login(user_input['password'])
|
||||
except InvalidAuthError:
|
||||
errors['base'] = 'invalid_auth'
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data={}
|
||||
)
|
||||
|
||||
schema = OrderedDict() # type: Dict[str, type]
|
||||
schema['password'] = str
|
||||
return await self.async_finish({})
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema(schema),
|
||||
data_schema=vol.Schema({'password': str}),
|
||||
errors=errors,
|
||||
)
|
||||
|
|
|
@ -7,11 +7,11 @@ from typing import Any, Dict, Optional, cast
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
|
@ -35,8 +35,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
|
||||
DEFAULT_TITLE = 'Trusted Networks'
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
users = await self.store.async_get_users()
|
||||
|
@ -44,8 +43,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
for user in users
|
||||
if not user.system_generated and user.is_active}
|
||||
|
||||
return LoginFlow(self, cast(str, context.get('ip_address')),
|
||||
available_users)
|
||||
return TrustedNetworksLoginFlow(
|
||||
self, cast(str, context.get('ip_address')), available_users)
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
|
@ -92,14 +91,14 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
raise InvalidAuthError('Not in trusted_networks')
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
class TrustedNetworksLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
||||
ip_address: str, available_users: Dict[str, Optional[str]]) \
|
||||
-> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
super().__init__(auth_provider)
|
||||
self._available_users = available_users
|
||||
self._ip_address = ip_address
|
||||
|
||||
|
@ -109,7 +108,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
try:
|
||||
self._auth_provider.async_validate_access(self._ip_address)
|
||||
cast(TrustedNetworksAuthProvider, self._auth_provider)\
|
||||
.async_validate_access(self._ip_address)
|
||||
|
||||
except InvalidAuthError:
|
||||
errors['base'] = 'invalid_auth'
|
||||
|
@ -125,10 +125,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
errors['base'] = 'invalid_auth'
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data=user_input
|
||||
)
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema = {'user': vol.In(self._available_users)}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth import auth_manager_from_config
|
||||
from homeassistant.auth import auth_manager_from_config, auth_store
|
||||
from homeassistant.auth.providers import (
|
||||
auth_provider_from_config, homeassistant as hass_auth)
|
||||
|
||||
|
@ -24,7 +24,7 @@ async def test_adding_user(data, hass):
|
|||
|
||||
|
||||
async def test_adding_user_duplicate_username(data, hass):
|
||||
"""Test adding a user."""
|
||||
"""Test adding a user with duplicate username."""
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
with pytest.raises(hass_auth.InvalidUser):
|
||||
data.add_auth('test-user', 'other-pass')
|
||||
|
@ -37,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
|
|||
|
||||
|
||||
async def test_validating_password_invalid_password(data, hass):
|
||||
"""Test validating an invalid user."""
|
||||
"""Test validating an invalid password."""
|
||||
data.add_auth('test-user', 'test-pass')
|
||||
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
|
@ -67,8 +67,9 @@ async def test_login_flow_validates(data, hass):
|
|||
data.add_auth('test-user', 'test-pass')
|
||||
await data.async_save()
|
||||
|
||||
provider = hass_auth.HassAuthProvider(hass, None, {})
|
||||
flow = hass_auth.LoginFlow(provider)
|
||||
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
|
||||
{'type': 'homeassistant'})
|
||||
flow = await provider.async_login_flow({})
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
||||
|
@ -91,6 +92,7 @@ async def test_login_flow_validates(data, hass):
|
|||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result['data']['username'] == 'test-user'
|
||||
|
||||
|
||||
async def test_saving_loading(data, hass):
|
||||
|
|
|
@ -72,7 +72,7 @@ async def test_login_flow(manager, provider):
|
|||
user = await manager.async_create_user("test-user")
|
||||
|
||||
# trusted network didn't loaded
|
||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
assert step['errors']['base'] == 'invalid_auth'
|
||||
|
@ -80,13 +80,13 @@ async def test_login_flow(manager, provider):
|
|||
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||
|
||||
# not from trusted network
|
||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
assert step['errors']['base'] == 'invalid_auth'
|
||||
|
||||
# from trusted network, list users
|
||||
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'})
|
||||
flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue