From 1ce51bfbd69110cb09562104799168533e73399f Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Tue, 21 Aug 2018 11:03:38 -0700 Subject: [PATCH] Refactoring login flow (#16104) * Abstract LoginFlow * Lint and typings --- homeassistant/auth/__init__.py | 2 +- homeassistant/auth/providers/__init__.py | 40 ++++++++++++++++--- homeassistant/auth/providers/homeassistant.py | 32 ++++++--------- .../auth/providers/insecure_example.py | 27 +++++-------- .../auth/providers/legacy_api_password.py | 33 +++++---------- .../auth/providers/trusted_networks.py | 23 +++++------ tests/auth/providers/test_homeassistant.py | 12 +++--- tests/auth/providers/test_trusted_networks.py | 6 +-- 8 files changed, 89 insertions(+), 86 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 51c9c25b474..3b61229d59a 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -259,7 +259,7 @@ class AuthManager: """Create a login flow.""" auth_provider = self._providers[handler] - return await auth_provider.async_credential_flow(context) + return await auth_provider.async_login_flow(context) async def _async_finish_login_flow( self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \ diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 328d83343d7..b2338a8d6ea 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -10,10 +10,11 @@ from voluptuous.humanize import humanize_error from homeassistant import data_entry_flow, requirements from homeassistant.core import callback, HomeAssistant from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID +from homeassistant.util import dt as dt_util from homeassistant.util.decorator import Registry -from homeassistant.auth.auth_store import AuthStore -from homeassistant.auth.models import Credentials, UserMeta +from ..auth_store import AuthStore +from ..models import Credentials, UserMeta _LOGGER = logging.getLogger(__name__) DATA_REQS = 'auth_prov_reqs_processed' @@ -80,9 +81,11 @@ class AuthProvider: # Implement by extending class - async def async_credential_flow( - self, context: Optional[Dict]) -> data_entry_flow.FlowHandler: - """Return the data flow for logging in with auth provider.""" + async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow': + """Return the data flow for logging in with auth provider. + + Auth provider should extend LoginFlow and return an instance. + """ raise NotImplementedError async def async_get_or_create_credentials( @@ -149,3 +152,30 @@ async def load_auth_provider_module( processed.add(provider) return module + + +class LoginFlow(data_entry_flow.FlowHandler): + """Handler for the login flow.""" + + def __init__(self, auth_provider: AuthProvider) -> None: + """Initialize the login flow.""" + self._auth_provider = auth_provider + self.created_at = dt_util.utcnow() + self.user = None + + async def async_step_init( + self, user_input: Optional[Dict[str, str]] = None) \ + -> Dict[str, Any]: + """Handle the first step of login flow. + + Return self.async_show_form(step_id='init') if user_input == None. + Return await self.async_finish(flow_result) if login init step pass. + """ + raise NotImplementedError + + async def async_finish(self, flow_result: Any) -> Dict: + """Handle the pass of login flow.""" + return self.async_create_entry( + title=self._auth_provider.name, + data=flow_result + ) diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index a2d91767b95..29be774cf8a 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -3,19 +3,18 @@ import base64 from collections import OrderedDict import hashlib import hmac -from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import +from typing import Any, Dict, List, Optional, cast import voluptuous as vol -from homeassistant import data_entry_flow from homeassistant.const import CONF_ID from homeassistant.core import callback, HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.auth.util import generate_secret - -from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow from ..models import Credentials, UserMeta +from ..util import generate_secret + STORAGE_VERSION = 1 STORAGE_KEY = 'auth_provider.homeassistant' @@ -159,10 +158,10 @@ class HassAuthProvider(AuthProvider): self.data = Data(self.hass) await self.data.async_load() - async def async_credential_flow( - self, context: Optional[Dict]) -> 'LoginFlow': + async def async_login_flow( + self, context: Optional[Dict]) -> LoginFlow: """Return a flow to login.""" - return LoginFlow(self) + return HassLoginFlow(self) async def async_validate_login(self, username: str, password: str) -> None: """Helper to validate a username and password.""" @@ -207,13 +206,9 @@ class HassAuthProvider(AuthProvider): pass -class LoginFlow(data_entry_flow.FlowHandler): +class HassLoginFlow(LoginFlow): """Handler for the login flow.""" - def __init__(self, auth_provider: HassAuthProvider) -> None: - """Initialize the login flow.""" - self._auth_provider = auth_provider - async def async_step_init( self, user_input: Optional[Dict[str, str]] = None) \ -> Dict[str, Any]: @@ -222,16 +217,15 @@ class LoginFlow(data_entry_flow.FlowHandler): if user_input is not None: try: - await self._auth_provider.async_validate_login( - user_input['username'], user_input['password']) + await cast(HassAuthProvider, self._auth_provider)\ + .async_validate_login(user_input['username'], + user_input['password']) except InvalidAuth: errors['base'] = 'invalid_auth' if not errors: - return self.async_create_entry( - title=self._auth_provider.name, - data=user_input - ) + user_input.pop('password') + return await self.async_finish(user_input) schema = OrderedDict() # type: Dict[str, type] schema['username'] = str diff --git a/homeassistant/auth/providers/insecure_example.py b/homeassistant/auth/providers/insecure_example.py index a4f411e69e0..d267ccb7a1c 100644 --- a/homeassistant/auth/providers/insecure_example.py +++ b/homeassistant/auth/providers/insecure_example.py @@ -1,15 +1,14 @@ """Example auth provider.""" from collections import OrderedDict import hmac -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import voluptuous as vol from homeassistant.exceptions import HomeAssistantError -from homeassistant import data_entry_flow from homeassistant.core import callback -from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow from ..models import Credentials, UserMeta @@ -33,10 +32,9 @@ class InvalidAuthError(HomeAssistantError): class ExampleAuthProvider(AuthProvider): """Example auth provider based on hardcoded usernames and passwords.""" - async def async_credential_flow( - self, context: Optional[Dict]) -> 'LoginFlow': + async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: """Return a flow to login.""" - return LoginFlow(self) + return ExampleLoginFlow(self) @callback def async_validate_login(self, username: str, password: str) -> None: @@ -90,13 +88,9 @@ class ExampleAuthProvider(AuthProvider): return UserMeta(name=name, is_active=True) -class LoginFlow(data_entry_flow.FlowHandler): +class ExampleLoginFlow(LoginFlow): """Handler for the login flow.""" - def __init__(self, auth_provider: ExampleAuthProvider) -> None: - """Initialize the login flow.""" - self._auth_provider = auth_provider - async def async_step_init( self, user_input: Optional[Dict[str, str]] = None) \ -> Dict[str, Any]: @@ -105,16 +99,15 @@ class LoginFlow(data_entry_flow.FlowHandler): if user_input is not None: try: - self._auth_provider.async_validate_login( - user_input['username'], user_input['password']) + cast(ExampleAuthProvider, self._auth_provider)\ + .async_validate_login(user_input['username'], + user_input['password']) except InvalidAuthError: errors['base'] = 'invalid_auth' if not errors: - return self.async_create_entry( - title=self._auth_provider.name, - data=user_input - ) + user_input.pop('password') + return await self.async_finish(user_input) schema = OrderedDict() # type: Dict[str, type] schema['username'] = str diff --git a/homeassistant/auth/providers/legacy_api_password.py b/homeassistant/auth/providers/legacy_api_password.py index 064cfc046bd..dffe458976c 100644 --- a/homeassistant/auth/providers/legacy_api_password.py +++ b/homeassistant/auth/providers/legacy_api_password.py @@ -3,18 +3,16 @@ Support Legacy API password auth provider. It will be removed when auth system production ready """ -from collections import OrderedDict import hmac -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import voluptuous as vol from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 -from homeassistant.exceptions import HomeAssistantError -from homeassistant import data_entry_flow from homeassistant.core import callback +from homeassistant.exceptions import HomeAssistantError -from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow from ..models import Credentials, UserMeta @@ -39,10 +37,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider): DEFAULT_TITLE = 'Legacy API Password' - async def async_credential_flow( - self, context: Optional[Dict]) -> 'LoginFlow': + async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: """Return a flow to login.""" - return LoginFlow(self) + return LegacyLoginFlow(self) @callback def async_validate_login(self, password: str) -> None: @@ -81,13 +78,9 @@ class LegacyApiPasswordAuthProvider(AuthProvider): return UserMeta(name=LEGACY_USER, is_active=True) -class LoginFlow(data_entry_flow.FlowHandler): +class LegacyLoginFlow(LoginFlow): """Handler for the login flow.""" - def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None: - """Initialize the login flow.""" - self._auth_provider = auth_provider - async def async_step_init( self, user_input: Optional[Dict[str, str]] = None) \ -> Dict[str, Any]: @@ -96,22 +89,16 @@ class LoginFlow(data_entry_flow.FlowHandler): if user_input is not None: try: - self._auth_provider.async_validate_login( - user_input['password']) + cast(LegacyApiPasswordAuthProvider, self._auth_provider)\ + .async_validate_login(user_input['password']) except InvalidAuthError: errors['base'] = 'invalid_auth' if not errors: - return self.async_create_entry( - title=self._auth_provider.name, - data={} - ) - - schema = OrderedDict() # type: Dict[str, type] - schema['password'] = str + return await self.async_finish({}) return self.async_show_form( step_id='init', - data_schema=vol.Schema(schema), + data_schema=vol.Schema({'password': str}), errors=errors, ) diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index 3233fa5537f..0bc37946e0b 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -7,11 +7,11 @@ from typing import Any, Dict, Optional, cast import voluptuous as vol -from homeassistant import data_entry_flow from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError -from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS + +from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow from ..models import Credentials, UserMeta CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ @@ -35,8 +35,7 @@ class TrustedNetworksAuthProvider(AuthProvider): DEFAULT_TITLE = 'Trusted Networks' - async def async_credential_flow( - self, context: Optional[Dict]) -> 'LoginFlow': + async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: """Return a flow to login.""" assert context is not None users = await self.store.async_get_users() @@ -44,8 +43,8 @@ class TrustedNetworksAuthProvider(AuthProvider): for user in users if not user.system_generated and user.is_active} - return LoginFlow(self, cast(str, context.get('ip_address')), - available_users) + return TrustedNetworksLoginFlow( + self, cast(str, context.get('ip_address')), available_users) async def async_get_or_create_credentials( self, flow_result: Dict[str, str]) -> Credentials: @@ -92,14 +91,14 @@ class TrustedNetworksAuthProvider(AuthProvider): raise InvalidAuthError('Not in trusted_networks') -class LoginFlow(data_entry_flow.FlowHandler): +class TrustedNetworksLoginFlow(LoginFlow): """Handler for the login flow.""" def __init__(self, auth_provider: TrustedNetworksAuthProvider, ip_address: str, available_users: Dict[str, Optional[str]]) \ -> None: """Initialize the login flow.""" - self._auth_provider = auth_provider + super().__init__(auth_provider) self._available_users = available_users self._ip_address = ip_address @@ -109,7 +108,8 @@ class LoginFlow(data_entry_flow.FlowHandler): """Handle the step of the form.""" errors = {} try: - self._auth_provider.async_validate_access(self._ip_address) + cast(TrustedNetworksAuthProvider, self._auth_provider)\ + .async_validate_access(self._ip_address) except InvalidAuthError: errors['base'] = 'invalid_auth' @@ -125,10 +125,7 @@ class LoginFlow(data_entry_flow.FlowHandler): errors['base'] = 'invalid_auth' if not errors: - return self.async_create_entry( - title=self._auth_provider.name, - data=user_input - ) + return await self.async_finish(user_input) schema = {'user': vol.In(self._available_users)} diff --git a/tests/auth/providers/test_homeassistant.py b/tests/auth/providers/test_homeassistant.py index 9db6293d98a..b87f981570e 100644 --- a/tests/auth/providers/test_homeassistant.py +++ b/tests/auth/providers/test_homeassistant.py @@ -4,7 +4,7 @@ from unittest.mock import Mock import pytest from homeassistant import data_entry_flow -from homeassistant.auth import auth_manager_from_config +from homeassistant.auth import auth_manager_from_config, auth_store from homeassistant.auth.providers import ( auth_provider_from_config, homeassistant as hass_auth) @@ -24,7 +24,7 @@ async def test_adding_user(data, hass): async def test_adding_user_duplicate_username(data, hass): - """Test adding a user.""" + """Test adding a user with duplicate username.""" data.add_auth('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidUser): data.add_auth('test-user', 'other-pass') @@ -37,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass): async def test_validating_password_invalid_password(data, hass): - """Test validating an invalid user.""" + """Test validating an invalid password.""" data.add_auth('test-user', 'test-pass') with pytest.raises(hass_auth.InvalidAuth): @@ -67,8 +67,9 @@ async def test_login_flow_validates(data, hass): data.add_auth('test-user', 'test-pass') await data.async_save() - provider = hass_auth.HassAuthProvider(hass, None, {}) - flow = hass_auth.LoginFlow(provider) + provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass), + {'type': 'homeassistant'}) + flow = await provider.async_login_flow({}) result = await flow.async_step_init() assert result['type'] == data_entry_flow.RESULT_TYPE_FORM @@ -91,6 +92,7 @@ async def test_login_flow_validates(data, hass): 'password': 'test-pass', }) assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result['data']['username'] == 'test-user' async def test_saving_loading(data, hass): diff --git a/tests/auth/providers/test_trusted_networks.py b/tests/auth/providers/test_trusted_networks.py index ca8b5bd90a2..5a7021a647a 100644 --- a/tests/auth/providers/test_trusted_networks.py +++ b/tests/auth/providers/test_trusted_networks.py @@ -72,7 +72,7 @@ async def test_login_flow(manager, provider): user = await manager.async_create_user("test-user") # trusted network didn't loaded - flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'}) + flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) step = await flow.async_step_init() assert step['step_id'] == 'init' assert step['errors']['base'] == 'invalid_auth' @@ -80,13 +80,13 @@ async def test_login_flow(manager, provider): provider.hass.http = Mock(trusted_networks=['192.168.0.1']) # not from trusted network - flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'}) + flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) step = await flow.async_step_init() assert step['step_id'] == 'init' assert step['errors']['base'] == 'invalid_auth' # from trusted network, list users - flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'}) + flow = await provider.async_login_flow({'ip_address': '192.168.0.1'}) step = await flow.async_step_init() assert step['step_id'] == 'init'