parent
63614a477a
commit
563588651c
5 changed files with 23 additions and 23 deletions
|
@ -137,8 +137,9 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
|
||||||
# user_input has been validate in caller
|
# user_input has been validate in caller
|
||||||
|
# set INPUT_FIELD_CODE as vol.Required is not user friendly
|
||||||
return await self.hass.async_add_executor_job(
|
return await self.hass.async_add_executor_job(
|
||||||
self._validate_2fa, user_id, user_input[INPUT_FIELD_CODE])
|
self._validate_2fa, user_id, user_input.get(INPUT_FIELD_CODE, ''))
|
||||||
|
|
||||||
def _validate_2fa(self, user_id: str, code: str) -> bool:
|
def _validate_2fa(self, user_id: str, code: str) -> bool:
|
||||||
"""Validate two factor authentication code."""
|
"""Validate two factor authentication code."""
|
||||||
|
|
|
@ -224,19 +224,27 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
expires = self.created_at + SESSION_EXPIRATION
|
expires = self.created_at + SESSION_EXPIRATION
|
||||||
if dt_util.utcnow() > expires:
|
if dt_util.utcnow() > expires:
|
||||||
errors['base'] = 'login_expired'
|
return self.async_abort(
|
||||||
else:
|
reason='login_expired'
|
||||||
result = await auth_module.async_validation(
|
)
|
||||||
self.user.id, user_input) # type: ignore
|
|
||||||
if not result:
|
result = await auth_module.async_validation(
|
||||||
errors['base'] = 'invalid_auth'
|
self.user.id, user_input) # type: ignore
|
||||||
|
if not result:
|
||||||
|
errors['base'] = 'invalid_code'
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return await self.async_finish(self.user)
|
return await self.async_finish(self.user)
|
||||||
|
|
||||||
|
description_placeholders = {
|
||||||
|
'mfa_module_name': auth_module.name,
|
||||||
|
'mfa_module_id': auth_module.id
|
||||||
|
} # type: Dict[str, str]
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id='mfa',
|
step_id='mfa',
|
||||||
data_schema=auth_module.input_schema,
|
data_schema=auth_module.input_schema,
|
||||||
|
description_placeholders=description_placeholders,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,7 @@ async def test_login(hass):
|
||||||
result = await hass.auth.login_flow.async_configure(
|
result = await hass.auth.login_flow.async_configure(
|
||||||
result['flow_id'], {'pin': 'invalid-code'})
|
result['flow_id'], {'pin': 'invalid-code'})
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
assert result['errors']['base'] == 'invalid_auth'
|
assert result['errors']['base'] == 'invalid_code'
|
||||||
|
|
||||||
result = await hass.auth.login_flow.async_configure(
|
result = await hass.auth.login_flow.async_configure(
|
||||||
result['flow_id'], {'pin': '123456'})
|
result['flow_id'], {'pin': '123456'})
|
||||||
|
|
|
@ -121,7 +121,7 @@ async def test_login_flow_validates_mfa(hass):
|
||||||
result['flow_id'], {'code': 'invalid-code'})
|
result['flow_id'], {'code': 'invalid-code'})
|
||||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
assert result['step_id'] == 'mfa'
|
assert result['step_id'] == 'mfa'
|
||||||
assert result['errors']['base'] == 'invalid_auth'
|
assert result['errors']['base'] == 'invalid_code'
|
||||||
|
|
||||||
with patch('pyotp.TOTP.verify', return_value=True):
|
with patch('pyotp.TOTP.verify', return_value=True):
|
||||||
result = await hass.auth.login_flow.async_configure(
|
result = await hass.auth.login_flow.async_configure(
|
||||||
|
|
|
@ -428,10 +428,10 @@ async def test_login_with_auth_module(mock_hass):
|
||||||
'pin': 'invalid-pin',
|
'pin': 'invalid-pin',
|
||||||
})
|
})
|
||||||
|
|
||||||
# Invalid auth error
|
# Invalid code error
|
||||||
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
assert step['step_id'] == 'mfa'
|
assert step['step_id'] == 'mfa'
|
||||||
assert step['errors'] == {'base': 'invalid_auth'}
|
assert step['errors'] == {'base': 'invalid_code'}
|
||||||
|
|
||||||
step = await manager.login_flow.async_configure(step['flow_id'], {
|
step = await manager.login_flow.async_configure(step['flow_id'], {
|
||||||
'pin': 'test-pin',
|
'pin': 'test-pin',
|
||||||
|
@ -571,18 +571,9 @@ async def test_auth_module_expired_session(mock_hass):
|
||||||
step = await manager.login_flow.async_configure(step['flow_id'], {
|
step = await manager.login_flow.async_configure(step['flow_id'], {
|
||||||
'pin': 'test-pin',
|
'pin': 'test-pin',
|
||||||
})
|
})
|
||||||
# Invalid auth due session timeout
|
# login flow abort due session timeout
|
||||||
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
|
assert step['type'] == data_entry_flow.RESULT_TYPE_ABORT
|
||||||
assert step['step_id'] == 'mfa'
|
assert step['reason'] == 'login_expired'
|
||||||
assert step['errors']['base'] == 'login_expired'
|
|
||||||
|
|
||||||
# The second try will fail as well
|
|
||||||
step = await manager.login_flow.async_configure(step['flow_id'], {
|
|
||||||
'pin': 'test-pin',
|
|
||||||
})
|
|
||||||
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
|
|
||||||
assert step['step_id'] == 'mfa'
|
|
||||||
assert step['errors']['base'] == 'login_expired'
|
|
||||||
|
|
||||||
|
|
||||||
async def test_enable_mfa_for_user(hass, hass_storage):
|
async def test_enable_mfa_for_user(hass, hass_storage):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue