Refactoring login flow (#16104)

* Abstract LoginFlow

* Lint and typings
This commit is contained in:
Jason Hu 2018-08-21 11:03:38 -07:00 committed by GitHub
parent cdb8361050
commit 1ce51bfbd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 89 additions and 86 deletions

View file

@ -259,7 +259,7 @@ class AuthManager:
"""Create a login flow."""
auth_provider = self._providers[handler]
return await auth_provider.async_credential_flow(context)
return await auth_provider.async_login_flow(context)
async def _async_finish_login_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \

View file

@ -10,10 +10,11 @@ from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback, HomeAssistant
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
from homeassistant.auth.auth_store import AuthStore
from homeassistant.auth.models import Credentials, UserMeta
from ..auth_store import AuthStore
from ..models import Credentials, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed'
@ -80,9 +81,11 @@ class AuthProvider:
# Implement by extending class
async def async_credential_flow(
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
"""Return the data flow for logging in with auth provider."""
async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow':
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
"""
raise NotImplementedError
async def async_get_or_create_credentials(
@ -149,3 +152,30 @@ async def load_auth_provider_module(
processed.add(provider)
return module
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self.created_at = dt_util.utcnow()
self.user = None
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input == None.
Return await self.async_finish(flow_result) if login init step pass.
"""
raise NotImplementedError
async def async_finish(self, flow_result: Any) -> Dict:
"""Handle the pass of login flow."""
return self.async_create_entry(
title=self._auth_provider.name,
data=flow_result
)

View file

@ -3,19 +3,18 @@ import base64
from collections import OrderedDict
import hashlib
import hmac
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
from typing import Any, Dict, List, Optional, cast
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.const import CONF_ID
from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.auth.util import generate_secret
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta
from ..util import generate_secret
STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant'
@ -159,10 +158,10 @@ class HassAuthProvider(AuthProvider):
self.data = Data(self.hass)
await self.data.async_load()
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
async def async_login_flow(
self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
return LoginFlow(self)
return HassLoginFlow(self)
async def async_validate_login(self, username: str, password: str) -> None:
"""Helper to validate a username and password."""
@ -207,13 +206,9 @@ class HassAuthProvider(AuthProvider):
pass
class LoginFlow(data_entry_flow.FlowHandler):
class HassLoginFlow(LoginFlow):
"""Handler for the login flow."""
def __init__(self, auth_provider: HassAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
@ -222,16 +217,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None:
try:
await self._auth_provider.async_validate_login(
user_input['username'], user_input['password'])
await cast(HassAuthProvider, self._auth_provider)\
.async_validate_login(user_input['username'],
user_input['password'])
except InvalidAuth:
errors['base'] = 'invalid_auth'
if not errors:
return self.async_create_entry(
title=self._auth_provider.name,
data=user_input
)
user_input.pop('password')
return await self.async_finish(user_input)
schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str

View file

@ -1,15 +1,14 @@
"""Example auth provider."""
from collections import OrderedDict
import hmac
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta
@ -33,10 +32,9 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
return LoginFlow(self)
return ExampleLoginFlow(self)
@callback
def async_validate_login(self, username: str, password: str) -> None:
@ -90,13 +88,9 @@ class ExampleAuthProvider(AuthProvider):
return UserMeta(name=name, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler):
class ExampleLoginFlow(LoginFlow):
"""Handler for the login flow."""
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
@ -105,16 +99,15 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None:
try:
self._auth_provider.async_validate_login(
user_input['username'], user_input['password'])
cast(ExampleAuthProvider, self._auth_provider)\
.async_validate_login(user_input['username'],
user_input['password'])
except InvalidAuthError:
errors['base'] = 'invalid_auth'
if not errors:
return self.async_create_entry(
title=self._auth_provider.name,
data=user_input
)
user_input.pop('password')
return await self.async_finish(user_input)
schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str

View file

@ -3,18 +3,16 @@ Support Legacy API password auth provider.
It will be removed when auth system production ready
"""
from collections import OrderedDict
import hmac
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast
import voluptuous as vol
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.exceptions import HomeAssistantError
from homeassistant import data_entry_flow
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta
@ -39,10 +37,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Legacy API Password'
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
return LoginFlow(self)
return LegacyLoginFlow(self)
@callback
def async_validate_login(self, password: str) -> None:
@ -81,13 +78,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
return UserMeta(name=LEGACY_USER, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler):
class LegacyLoginFlow(LoginFlow):
"""Handler for the login flow."""
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
@ -96,22 +89,16 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is not None:
try:
self._auth_provider.async_validate_login(
user_input['password'])
cast(LegacyApiPasswordAuthProvider, self._auth_provider)\
.async_validate_login(user_input['password'])
except InvalidAuthError:
errors['base'] = 'invalid_auth'
if not errors:
return self.async_create_entry(
title=self._auth_provider.name,
data={}
)
schema = OrderedDict() # type: Dict[str, type]
schema['password'] = str
return await self.async_finish({})
return self.async_show_form(
step_id='init',
data_schema=vol.Schema(schema),
data_schema=vol.Schema({'password': str}),
errors=errors,
)

View file

@ -7,11 +7,11 @@ from typing import Any, Dict, Optional, cast
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
@ -35,8 +35,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Trusted Networks'
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
assert context is not None
users = await self.store.async_get_users()
@ -44,8 +43,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
for user in users
if not user.system_generated and user.is_active}
return LoginFlow(self, cast(str, context.get('ip_address')),
available_users)
return TrustedNetworksLoginFlow(
self, cast(str, context.get('ip_address')), available_users)
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
@ -92,14 +91,14 @@ class TrustedNetworksAuthProvider(AuthProvider):
raise InvalidAuthError('Not in trusted_networks')
class LoginFlow(data_entry_flow.FlowHandler):
class TrustedNetworksLoginFlow(LoginFlow):
"""Handler for the login flow."""
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
ip_address: str, available_users: Dict[str, Optional[str]]) \
-> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
super().__init__(auth_provider)
self._available_users = available_users
self._ip_address = ip_address
@ -109,7 +108,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
"""Handle the step of the form."""
errors = {}
try:
self._auth_provider.async_validate_access(self._ip_address)
cast(TrustedNetworksAuthProvider, self._auth_provider)\
.async_validate_access(self._ip_address)
except InvalidAuthError:
errors['base'] = 'invalid_auth'
@ -125,10 +125,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors['base'] = 'invalid_auth'
if not errors:
return self.async_create_entry(
title=self._auth_provider.name,
data=user_input
)
return await self.async_finish(user_input)
schema = {'user': vol.In(self._available_users)}

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock
import pytest
from homeassistant import data_entry_flow
from homeassistant.auth import auth_manager_from_config
from homeassistant.auth import auth_manager_from_config, auth_store
from homeassistant.auth.providers import (
auth_provider_from_config, homeassistant as hass_auth)
@ -24,7 +24,7 @@ async def test_adding_user(data, hass):
async def test_adding_user_duplicate_username(data, hass):
"""Test adding a user."""
"""Test adding a user with duplicate username."""
data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidUser):
data.add_auth('test-user', 'other-pass')
@ -37,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
async def test_validating_password_invalid_password(data, hass):
"""Test validating an invalid user."""
"""Test validating an invalid password."""
data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth):
@ -67,8 +67,9 @@ async def test_login_flow_validates(data, hass):
data.add_auth('test-user', 'test-pass')
await data.async_save()
provider = hass_auth.HassAuthProvider(hass, None, {})
flow = hass_auth.LoginFlow(provider)
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
{'type': 'homeassistant'})
flow = await provider.async_login_flow({})
result = await flow.async_step_init()
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
@ -91,6 +92,7 @@ async def test_login_flow_validates(data, hass):
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data']['username'] == 'test-user'
async def test_saving_loading(data, hass):

View file

@ -72,7 +72,7 @@ async def test_login_flow(manager, provider):
user = await manager.async_create_user("test-user")
# trusted network didn't loaded
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth'
@ -80,13 +80,13 @@ async def test_login_flow(manager, provider):
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
# not from trusted network
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth'
# from trusted network, list users
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'})
flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
step = await flow.async_step_init()
assert step['step_id'] == 'init'