Refactoring login flow (#16104)

* Abstract LoginFlow

* Lint and typings
This commit is contained in:
Jason Hu 2018-08-21 11:03:38 -07:00 committed by GitHub
parent cdb8361050
commit 1ce51bfbd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 89 additions and 86 deletions

View file

@ -259,7 +259,7 @@ class AuthManager:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self._providers[handler] auth_provider = self._providers[handler]
return await auth_provider.async_credential_flow(context) return await auth_provider.async_login_flow(context)
async def _async_finish_login_flow( async def _async_finish_login_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \ self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \

View file

@ -10,10 +10,11 @@ from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback, HomeAssistant from homeassistant.core import callback, HomeAssistant
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.auth.auth_store import AuthStore from ..auth_store import AuthStore
from homeassistant.auth.models import Credentials, UserMeta from ..models import Credentials, UserMeta
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed' DATA_REQS = 'auth_prov_reqs_processed'
@ -80,9 +81,11 @@ class AuthProvider:
# Implement by extending class # Implement by extending class
async def async_credential_flow( async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow':
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler: """Return the data flow for logging in with auth provider.
"""Return the data flow for logging in with auth provider."""
Auth provider should extend LoginFlow and return an instance.
"""
raise NotImplementedError raise NotImplementedError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
@ -149,3 +152,30 @@ async def load_auth_provider_module(
processed.add(provider) processed.add(provider)
return module return module
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self.created_at = dt_util.utcnow()
self.user = None
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input == None.
Return await self.async_finish(flow_result) if login init step pass.
"""
raise NotImplementedError
async def async_finish(self, flow_result: Any) -> Dict:
"""Handle the pass of login flow."""
return self.async_create_entry(
title=self._auth_provider.name,
data=flow_result
)

View file

@ -3,19 +3,18 @@ import base64
from collections import OrderedDict from collections import OrderedDict
import hashlib import hashlib
import hmac import hmac
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import from typing import Any, Dict, List, Optional, cast
import voluptuous as vol import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.const import CONF_ID from homeassistant.const import CONF_ID
from homeassistant.core import callback, HomeAssistant from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.auth.util import generate_secret from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from ..models import Credentials, UserMeta from ..models import Credentials, UserMeta
from ..util import generate_secret
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant' STORAGE_KEY = 'auth_provider.homeassistant'
@ -159,10 +158,10 @@ class HassAuthProvider(AuthProvider):
self.data = Data(self.hass) self.data = Data(self.hass)
await self.data.async_load() await self.data.async_load()
async def async_credential_flow( async def async_login_flow(
self, context: Optional[Dict]) -> 'LoginFlow': self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return LoginFlow(self) return HassLoginFlow(self)
async def async_validate_login(self, username: str, password: str) -> None: async def async_validate_login(self, username: str, password: str) -> None:
"""Helper to validate a username and password.""" """Helper to validate a username and password."""
@ -207,13 +206,9 @@ class HassAuthProvider(AuthProvider):
pass pass
class LoginFlow(data_entry_flow.FlowHandler): class HassLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__(self, auth_provider: HassAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]: -> Dict[str, Any]:
@ -222,16 +217,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None: if user_input is not None:
try: try:
await self._auth_provider.async_validate_login( await cast(HassAuthProvider, self._auth_provider)\
user_input['username'], user_input['password']) .async_validate_login(user_input['username'],
user_input['password'])
except InvalidAuth: except InvalidAuth:
errors['base'] = 'invalid_auth' errors['base'] = 'invalid_auth'
if not errors: if not errors:
return self.async_create_entry( user_input.pop('password')
title=self._auth_provider.name, return await self.async_finish(user_input)
data=user_input
)
schema = OrderedDict() # type: Dict[str, type] schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str schema['username'] = str

View file

@ -1,15 +1,14 @@
"""Example auth provider.""" """Example auth provider."""
from collections import OrderedDict from collections import OrderedDict
import hmac import hmac
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant import data_entry_flow
from homeassistant.core import callback from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta from ..models import Credentials, UserMeta
@ -33,10 +32,9 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider): class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords.""" """Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow( async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login.""" """Return a flow to login."""
return LoginFlow(self) return ExampleLoginFlow(self)
@callback @callback
def async_validate_login(self, username: str, password: str) -> None: def async_validate_login(self, username: str, password: str) -> None:
@ -90,13 +88,9 @@ class ExampleAuthProvider(AuthProvider):
return UserMeta(name=name, is_active=True) return UserMeta(name=name, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler): class ExampleLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]: -> Dict[str, Any]:
@ -105,16 +99,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None: if user_input is not None:
try: try:
self._auth_provider.async_validate_login( cast(ExampleAuthProvider, self._auth_provider)\
user_input['username'], user_input['password']) .async_validate_login(user_input['username'],
user_input['password'])
except InvalidAuthError: except InvalidAuthError:
errors['base'] = 'invalid_auth' errors['base'] = 'invalid_auth'
if not errors: if not errors:
return self.async_create_entry( user_input.pop('password')
title=self._auth_provider.name, return await self.async_finish(user_input)
data=user_input
)
schema = OrderedDict() # type: Dict[str, type] schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str schema['username'] = str

View file

@ -3,18 +3,16 @@ Support Legacy API password auth provider.
It will be removed when auth system production ready It will be removed when auth system production ready
""" """
from collections import OrderedDict
import hmac import hmac
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.exceptions import HomeAssistantError
from homeassistant import data_entry_flow
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta from ..models import Credentials, UserMeta
@ -39,10 +37,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Legacy API Password' DEFAULT_TITLE = 'Legacy API Password'
async def async_credential_flow( async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login.""" """Return a flow to login."""
return LoginFlow(self) return LegacyLoginFlow(self)
@callback @callback
def async_validate_login(self, password: str) -> None: def async_validate_login(self, password: str) -> None:
@ -81,13 +78,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
return UserMeta(name=LEGACY_USER, is_active=True) return UserMeta(name=LEGACY_USER, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler): class LegacyLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]: -> Dict[str, Any]:
@ -96,22 +89,16 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None: if user_input is not None:
try: try:
self._auth_provider.async_validate_login( cast(LegacyApiPasswordAuthProvider, self._auth_provider)\
user_input['password']) .async_validate_login(user_input['password'])
except InvalidAuthError: except InvalidAuthError:
errors['base'] = 'invalid_auth' errors['base'] = 'invalid_auth'
if not errors: if not errors:
return self.async_create_entry( return await self.async_finish({})
title=self._auth_provider.name,
data={}
)
schema = OrderedDict() # type: Dict[str, type]
schema['password'] = str
return self.async_show_form( return self.async_show_form(
step_id='init', step_id='init',
data_schema=vol.Schema(schema), data_schema=vol.Schema({'password': str}),
errors=errors, errors=errors,
) )

View file

@ -7,11 +7,11 @@ from typing import Any, Dict, Optional, cast
import voluptuous as vol import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta from ..models import Credentials, UserMeta
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
@ -35,8 +35,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Trusted Networks' DEFAULT_TITLE = 'Trusted Networks'
async def async_credential_flow( async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
users = await self.store.async_get_users() users = await self.store.async_get_users()
@ -44,8 +43,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
for user in users for user in users
if not user.system_generated and user.is_active} if not user.system_generated and user.is_active}
return LoginFlow(self, cast(str, context.get('ip_address')), return TrustedNetworksLoginFlow(
available_users) self, cast(str, context.get('ip_address')), available_users)
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]) -> Credentials:
@ -92,14 +91,14 @@ class TrustedNetworksAuthProvider(AuthProvider):
raise InvalidAuthError('Not in trusted_networks') raise InvalidAuthError('Not in trusted_networks')
class LoginFlow(data_entry_flow.FlowHandler): class TrustedNetworksLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__(self, auth_provider: TrustedNetworksAuthProvider, def __init__(self, auth_provider: TrustedNetworksAuthProvider,
ip_address: str, available_users: Dict[str, Optional[str]]) \ ip_address: str, available_users: Dict[str, Optional[str]]) \
-> None: -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider super().__init__(auth_provider)
self._available_users = available_users self._available_users = available_users
self._ip_address = ip_address self._ip_address = ip_address
@ -109,7 +108,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
try: try:
self._auth_provider.async_validate_access(self._ip_address) cast(TrustedNetworksAuthProvider, self._auth_provider)\
.async_validate_access(self._ip_address)
except InvalidAuthError: except InvalidAuthError:
errors['base'] = 'invalid_auth' errors['base'] = 'invalid_auth'
@ -125,10 +125,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors['base'] = 'invalid_auth' errors['base'] = 'invalid_auth'
if not errors: if not errors:
return self.async_create_entry( return await self.async_finish(user_input)
title=self._auth_provider.name,
data=user_input
)
schema = {'user': vol.In(self._available_users)} schema = {'user': vol.In(self._available_users)}

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock
import pytest import pytest
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import auth_manager_from_config from homeassistant.auth import auth_manager_from_config, auth_store
from homeassistant.auth.providers import ( from homeassistant.auth.providers import (
auth_provider_from_config, homeassistant as hass_auth) auth_provider_from_config, homeassistant as hass_auth)
@ -24,7 +24,7 @@ async def test_adding_user(data, hass):
async def test_adding_user_duplicate_username(data, hass): async def test_adding_user_duplicate_username(data, hass):
"""Test adding a user.""" """Test adding a user with duplicate username."""
data.add_auth('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidUser): with pytest.raises(hass_auth.InvalidUser):
data.add_auth('test-user', 'other-pass') data.add_auth('test-user', 'other-pass')
@ -37,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
async def test_validating_password_invalid_password(data, hass): async def test_validating_password_invalid_password(data, hass):
"""Test validating an invalid user.""" """Test validating an invalid password."""
data.add_auth('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
@ -67,8 +67,9 @@ async def test_login_flow_validates(data, hass):
data.add_auth('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
await data.async_save() await data.async_save()
provider = hass_auth.HassAuthProvider(hass, None, {}) provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
flow = hass_auth.LoginFlow(provider) {'type': 'homeassistant'})
flow = await provider.async_login_flow({})
result = await flow.async_step_init() result = await flow.async_step_init()
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
@ -91,6 +92,7 @@ async def test_login_flow_validates(data, hass):
'password': 'test-pass', 'password': 'test-pass',
}) })
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data']['username'] == 'test-user'
async def test_saving_loading(data, hass): async def test_saving_loading(data, hass):

View file

@ -72,7 +72,7 @@ async def test_login_flow(manager, provider):
user = await manager.async_create_user("test-user") user = await manager.async_create_user("test-user")
# trusted network didn't loaded # trusted network didn't loaded
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'}) flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
step = await flow.async_step_init() step = await flow.async_step_init()
assert step['step_id'] == 'init' assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth' assert step['errors']['base'] == 'invalid_auth'
@ -80,13 +80,13 @@ async def test_login_flow(manager, provider):
provider.hass.http = Mock(trusted_networks=['192.168.0.1']) provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
# not from trusted network # not from trusted network
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'}) flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
step = await flow.async_step_init() step = await flow.async_step_init()
assert step['step_id'] == 'init' assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth' assert step['errors']['base'] == 'invalid_auth'
# from trusted network, list users # from trusted network, list users
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'}) flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
step = await flow.async_step_init() step = await flow.async_step_init()
assert step['step_id'] == 'init' assert step['step_id'] == 'init'