Convert getting and removing access tokens to normal functions (#108670)
This commit is contained in:
parent
904032e944
commit
2eea658fd8
14 changed files with 98 additions and 124 deletions
|
@ -458,23 +458,22 @@ class AuthManager:
|
||||||
credential,
|
credential,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_refresh_token(
|
@callback
|
||||||
self, token_id: str
|
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||||
) -> models.RefreshToken | None:
|
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
return await self._store.async_get_refresh_token(token_id)
|
return self._store.async_get_refresh_token(token_id)
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(
|
@callback
|
||||||
|
def async_get_refresh_token_by_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> models.RefreshToken | None:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
return await self._store.async_get_refresh_token_by_token(token)
|
return self._store.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
async def async_remove_refresh_token(
|
@callback
|
||||||
self, refresh_token: models.RefreshToken
|
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||||
) -> None:
|
|
||||||
"""Delete a refresh token."""
|
"""Delete a refresh token."""
|
||||||
await self._store.async_remove_refresh_token(refresh_token)
|
self._store.async_remove_refresh_token(refresh_token)
|
||||||
|
|
||||||
callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
|
callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
|
||||||
for revoke_callback in callbacks:
|
for revoke_callback in callbacks:
|
||||||
|
@ -554,16 +553,15 @@ class AuthManager:
|
||||||
if provider := self._async_resolve_provider(refresh_token):
|
if provider := self._async_resolve_provider(refresh_token):
|
||||||
provider.async_validate_refresh_token(refresh_token, remote_ip)
|
provider.async_validate_refresh_token(refresh_token, remote_ip)
|
||||||
|
|
||||||
async def async_validate_access_token(
|
@callback
|
||||||
self, token: str
|
def async_validate_access_token(self, token: str) -> models.RefreshToken | None:
|
||||||
) -> models.RefreshToken | None:
|
|
||||||
"""Return refresh token if an access token is valid."""
|
"""Return refresh token if an access token is valid."""
|
||||||
try:
|
try:
|
||||||
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
|
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
refresh_token = await self.async_get_refresh_token(
|
refresh_token = self.async_get_refresh_token(
|
||||||
cast(str, unverif_claims.get("iss"))
|
cast(str, unverif_claims.get("iss"))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -207,18 +207,16 @@ class AuthStore:
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
async def async_remove_refresh_token(
|
@callback
|
||||||
self, refresh_token: models.RefreshToken
|
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||||
) -> None:
|
|
||||||
"""Remove a refresh token."""
|
"""Remove a refresh token."""
|
||||||
for user in self._users.values():
|
for user in self._users.values():
|
||||||
if user.refresh_tokens.pop(refresh_token.id, None):
|
if user.refresh_tokens.pop(refresh_token.id, None):
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
break
|
break
|
||||||
|
|
||||||
async def async_get_refresh_token(
|
@callback
|
||||||
self, token_id: str
|
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||||
) -> models.RefreshToken | None:
|
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
for user in self._users.values():
|
for user in self._users.values():
|
||||||
refresh_token = user.refresh_tokens.get(token_id)
|
refresh_token = user.refresh_tokens.get(token_id)
|
||||||
|
@ -227,7 +225,8 @@ class AuthStore:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(
|
@callback
|
||||||
|
def async_get_refresh_token_by_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> models.RefreshToken | None:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
|
|
|
@ -124,7 +124,6 @@ as part of a config flow.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
@ -220,12 +219,12 @@ class RevokeTokenView(HomeAssistantView):
|
||||||
if (token := data.get("token")) is None:
|
if (token := data.get("token")) is None:
|
||||||
return web.Response(status=HTTPStatus.OK)
|
return web.Response(status=HTTPStatus.OK)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
refresh_token = hass.auth.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return web.Response(status=HTTPStatus.OK)
|
return web.Response(status=HTTPStatus.OK)
|
||||||
|
|
||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
return web.Response(status=HTTPStatus.OK)
|
return web.Response(status=HTTPStatus.OK)
|
||||||
|
|
||||||
|
|
||||||
|
@ -355,7 +354,7 @@ class TokenView(HomeAssistantView):
|
||||||
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
refresh_token = hass.auth.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return self.json(
|
return self.json(
|
||||||
|
@ -597,7 +596,7 @@ async def websocket_delete_refresh_token(
|
||||||
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
|
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
|
||||||
return
|
return
|
||||||
|
|
||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
|
|
||||||
connection.send_result(msg["id"], {})
|
connection.send_result(msg["id"], {})
|
||||||
|
|
||||||
|
@ -613,28 +612,23 @@ async def websocket_delete_all_refresh_tokens(
|
||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle delete all refresh tokens request."""
|
"""Handle delete all refresh tokens request."""
|
||||||
tasks = []
|
|
||||||
current_refresh_token: RefreshToken
|
current_refresh_token: RefreshToken
|
||||||
for token in connection.user.refresh_tokens.values():
|
remove_failed = False
|
||||||
|
for token in list(connection.user.refresh_tokens.values()):
|
||||||
if token.id == connection.refresh_token_id:
|
if token.id == connection.refresh_token_id:
|
||||||
# Skip the current refresh token as it has revoke_callback,
|
# Skip the current refresh token as it has revoke_callback,
|
||||||
# which cancels/closes the connection.
|
# which cancels/closes the connection.
|
||||||
# It will be removed after sending the result.
|
# It will be removed after sending the result.
|
||||||
current_refresh_token = token
|
current_refresh_token = token
|
||||||
continue
|
continue
|
||||||
tasks.append(
|
try:
|
||||||
hass.async_create_task(hass.auth.async_remove_refresh_token(token))
|
hass.auth.async_remove_refresh_token(token)
|
||||||
)
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
getLogger(__name__).exception(
|
||||||
remove_failed = False
|
"During refresh token removal, the following error occurred: %s",
|
||||||
if tasks:
|
err,
|
||||||
for result in await asyncio.gather(*tasks, return_exceptions=True):
|
)
|
||||||
if isinstance(result, Exception):
|
remove_failed = True
|
||||||
getLogger(__name__).exception(
|
|
||||||
"During refresh token removal, the following error occurred: %s",
|
|
||||||
result,
|
|
||||||
)
|
|
||||||
remove_failed = True
|
|
||||||
|
|
||||||
if remove_failed:
|
if remove_failed:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
|
@ -643,7 +637,8 @@ async def websocket_delete_all_refresh_tokens(
|
||||||
else:
|
else:
|
||||||
connection.send_result(msg["id"], {})
|
connection.send_result(msg["id"], {})
|
||||||
|
|
||||||
hass.async_create_task(hass.auth.async_remove_refresh_token(current_refresh_token))
|
# This will close the connection so we need to send the result first.
|
||||||
|
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
|
|
|
@ -151,7 +151,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||||
if auth_type != "Bearer":
|
if auth_type != "Bearer":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(auth_val)
|
refresh_token = hass.auth.async_validate_access_token(auth_val)
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return False
|
return False
|
||||||
|
@ -189,7 +189,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||||
if claims["params"] != params:
|
if claims["params"] != params:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(claims["iss"])
|
refresh_token = hass.auth.async_get_refresh_token(claims["iss"])
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -259,7 +259,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
||||||
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
|
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id)
|
refresh_token = hass.auth.async_get_refresh_token(refresh_token_id)
|
||||||
if refresh_token is None or refresh_token.credential is None:
|
if refresh_token is None or refresh_token.credential is None:
|
||||||
return self.json_message(
|
return self.json_message(
|
||||||
"Credentials for user not available", HTTPStatus.FORBIDDEN
|
"Credentials for user not available", HTTPStatus.FORBIDDEN
|
||||||
|
|
|
@ -80,9 +80,7 @@ class AuthPhase:
|
||||||
raise Disconnect from err
|
raise Disconnect from err
|
||||||
|
|
||||||
if (access_token := valid_msg.get("access_token")) and (
|
if (access_token := valid_msg.get("access_token")) and (
|
||||||
refresh_token := await self._hass.auth.async_validate_access_token(
|
refresh_token := self._hass.auth.async_validate_access_token(access_token)
|
||||||
access_token
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
conn = ActiveConnection(
|
conn = ActiveConnection(
|
||||||
self._logger,
|
self._logger,
|
||||||
|
|
|
@ -371,7 +371,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
|
||||||
assert refresh_token.client_id == CLIENT_ID
|
assert refresh_token.client_id == CLIENT_ID
|
||||||
|
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
assert await manager.async_validate_access_token(access_token) is refresh_token
|
assert manager.async_validate_access_token(access_token) is refresh_token
|
||||||
|
|
||||||
# We patch time directly here because we want the access token to be created with
|
# We patch time directly here because we want the access token to be created with
|
||||||
# an expired time, but we do not want to freeze time so that jwt will compare it
|
# an expired time, but we do not want to freeze time so that jwt will compare it
|
||||||
|
@ -385,7 +385,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
|
||||||
):
|
):
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
assert await manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_generating_system_user(hass: HomeAssistant) -> None:
|
async def test_generating_system_user(hass: HomeAssistant) -> None:
|
||||||
|
@ -572,10 +572,10 @@ async def test_remove_refresh_token(mock_hass) -> None:
|
||||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
await manager.async_remove_refresh_token(refresh_token)
|
manager.async_remove_refresh_token(refresh_token)
|
||||||
|
|
||||||
assert await manager.async_get_refresh_token(refresh_token.id) is None
|
assert manager.async_get_refresh_token(refresh_token.id) is None
|
||||||
assert await manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_register_revoke_token_callback(mock_hass) -> None:
|
async def test_register_revoke_token_callback(mock_hass) -> None:
|
||||||
|
@ -591,7 +591,7 @@ async def test_register_revoke_token_callback(mock_hass) -> None:
|
||||||
called = True
|
called = True
|
||||||
|
|
||||||
manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
||||||
await manager.async_remove_refresh_token(refresh_token)
|
manager.async_remove_refresh_token(refresh_token)
|
||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
|
@ -610,7 +610,7 @@ async def test_unregister_revoke_token_callback(mock_hass) -> None:
|
||||||
unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
||||||
unregister()
|
unregister()
|
||||||
|
|
||||||
await manager.async_remove_refresh_token(refresh_token)
|
manager.async_remove_refresh_token(refresh_token)
|
||||||
assert not called
|
assert not called
|
||||||
|
|
||||||
|
|
||||||
|
@ -664,7 +664,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
jwt_key = refresh_token.jwt_key
|
jwt_key = refresh_token.jwt_key
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
@ -675,9 +675,9 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
||||||
access_token_expiration=timedelta(days=3000),
|
access_token_expiration=timedelta(days=3000),
|
||||||
)
|
)
|
||||||
|
|
||||||
await manager.async_remove_refresh_token(refresh_token)
|
manager.async_remove_refresh_token(refresh_token)
|
||||||
assert refresh_token.id not in user.refresh_tokens
|
assert refresh_token.id not in user.refresh_tokens
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt is None, "Previous issued access token has been invoked"
|
assert rt is None, "Previous issued access token has been invoked"
|
||||||
|
|
||||||
refresh_token_2 = await manager.async_create_refresh_token(
|
refresh_token_2 = await manager.async_create_refresh_token(
|
||||||
|
@ -694,7 +694,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
||||||
assert access_token != access_token_2
|
assert access_token != access_token_2
|
||||||
assert jwt_key != jwt_key_2
|
assert jwt_key != jwt_key_2
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token_2)
|
rt = manager.async_validate_access_token(access_token_2)
|
||||||
jwt_payload = jwt.decode(access_token_2, rt.jwt_key, algorithms=["HS256"])
|
jwt_payload = jwt.decode(access_token_2, rt.jwt_key, algorithms=["HS256"])
|
||||||
assert jwt_payload["iss"] == refresh_token_2.id
|
assert jwt_payload["iss"] == refresh_token_2.id
|
||||||
assert (
|
assert (
|
||||||
|
@ -1144,7 +1144,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
||||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
# Now we corrupt the signature
|
# Now we corrupt the signature
|
||||||
|
@ -1154,7 +1154,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
||||||
|
|
||||||
assert access_token != invalid_token
|
assert access_token != invalid_token
|
||||||
|
|
||||||
result = await manager.async_validate_access_token(invalid_token)
|
result = manager.async_validate_access_token(invalid_token)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1171,7 +1171,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
|
||||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
# Now we make the signature all nulls
|
# Now we make the signature all nulls
|
||||||
|
@ -1181,7 +1181,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
|
||||||
|
|
||||||
assert access_token != invalid_token
|
assert access_token != invalid_token
|
||||||
|
|
||||||
result = await manager.async_validate_access_token(invalid_token)
|
result = manager.async_validate_access_token(invalid_token)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1198,7 +1198,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
|
||||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
# Now we make the signature all nulls
|
# Now we make the signature all nulls
|
||||||
|
@ -1207,7 +1207,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
|
||||||
|
|
||||||
assert access_token != invalid_token
|
assert access_token != invalid_token
|
||||||
|
|
||||||
result = await manager.async_validate_access_token(invalid_token)
|
result = manager.async_validate_access_token(invalid_token)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1225,17 +1225,17 @@ async def test_access_token_with_empty_key(mock_hass) -> None:
|
||||||
|
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
await manager.async_remove_refresh_token(refresh_token)
|
manager.async_remove_refresh_token(refresh_token)
|
||||||
# Now remove the token from the keyring
|
# Now remove the token from the keyring
|
||||||
# so we will get an empty key
|
# so we will get an empty key
|
||||||
|
|
||||||
assert await manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
|
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
|
||||||
"""Test rejecting access tokens with impossible sizes."""
|
"""Test rejecting access tokens with impossible sizes."""
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
assert await manager.async_validate_access_token("a" * 10000) is None
|
assert manager.async_validate_access_token("a" * 10000) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
||||||
|
@ -1245,7 +1245,7 @@ async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
||||||
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||||
)
|
)
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
assert await manager.async_validate_access_token(token_with_invalid_json) is None
|
assert manager.async_validate_access_token(token_with_invalid_json) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
||||||
|
@ -1255,7 +1255,7 @@ async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
||||||
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||||
)
|
)
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
assert await manager.async_validate_access_token(token_not_a_dict_json) is None
|
assert manager.async_validate_access_token(token_not_a_dict_json) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_access_token_that_expires_soon(mock_hass) -> None:
|
async def test_access_token_that_expires_soon(mock_hass) -> None:
|
||||||
|
@ -1272,11 +1272,11 @@ async def test_access_token_that_expires_soon(mock_hass) -> None:
|
||||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
with freeze_time(now + timedelta(minutes=1)):
|
with freeze_time(now + timedelta(minutes=1)):
|
||||||
assert await manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_access_token_from_the_future(mock_hass) -> None:
|
async def test_access_token_from_the_future(mock_hass) -> None:
|
||||||
|
@ -1296,8 +1296,8 @@ async def test_access_token_from_the_future(mock_hass) -> None:
|
||||||
)
|
)
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
assert await manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
with freeze_time(now + timedelta(days=365)):
|
with freeze_time(now + timedelta(days=365)):
|
||||||
rt = await manager.async_validate_access_token(access_token)
|
rt = manager.async_validate_access_token(access_token)
|
||||||
assert rt.id == refresh_token.id
|
assert rt.id == refresh_token.id
|
||||||
|
|
|
@ -588,7 +588,7 @@ async def test_api_fire_event_context(
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
assert len(test_value) == 1
|
assert len(test_value) == 1
|
||||||
assert test_value[0].context.user_id == refresh_token.user.id
|
assert test_value[0].context.user_id == refresh_token.user.id
|
||||||
|
@ -606,7 +606,7 @@ async def test_api_call_service_context(
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert calls[0].context.user_id == refresh_token.user.id
|
assert calls[0].context.user_id == refresh_token.user.id
|
||||||
|
@ -622,7 +622,7 @@ async def test_api_set_state_context(
|
||||||
headers={"authorization": f"Bearer {hass_access_token}"},
|
headers={"authorization": f"Bearer {hass_access_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
state = hass.states.get("light.kitchen")
|
state = hass.states.get("light.kitchen")
|
||||||
assert state.context.user_id == refresh_token.user.id
|
assert state.context.user_id == refresh_token.user.id
|
||||||
|
|
|
@ -88,9 +88,7 @@ async def test_login_new_user_and_trying_refresh_token(
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
|
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
assert tokens["ha_auth_provider"] == "insecure_example"
|
assert tokens["ha_auth_provider"] == "insecure_example"
|
||||||
|
|
||||||
# Use refresh token to get more tokens.
|
# Use refresh token to get more tokens.
|
||||||
|
@ -106,9 +104,7 @@ async def test_login_new_user_and_trying_refresh_token(
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert "refresh_token" not in tokens
|
assert "refresh_token" not in tokens
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test using access token to hit API.
|
# Test using access token to hit API.
|
||||||
resp = await client.get("/api/")
|
resp = await client.get("/api/")
|
||||||
|
@ -205,7 +201,7 @@ async def test_ws_current_user(
|
||||||
"""Test the current user command with Home Assistant creds."""
|
"""Test the current user command with Home Assistant creds."""
|
||||||
assert await async_setup_component(hass, "auth", {})
|
assert await async_setup_component(hass, "auth", {})
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
user = refresh_token.user
|
user = refresh_token.user
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
|
||||||
|
@ -275,9 +271,7 @@ async def test_refresh_token_system_generated(
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_different_client_id(
|
async def test_refresh_token_different_client_id(
|
||||||
|
@ -323,9 +317,7 @@ async def test_refresh_token_different_client_id(
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_checks_local_only_user(
|
async def test_refresh_token_checks_local_only_user(
|
||||||
|
@ -406,16 +398,14 @@ async def test_revoking_refresh_token(
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Revoke refresh token
|
# Revoke refresh token
|
||||||
resp = await client.post(url, data={**base_data, "token": refresh_token.token})
|
resp = await client.post(url, data={**base_data, "token": refresh_token.token})
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
|
|
||||||
# Old access token should be no longer valid
|
# Old access token should be no longer valid
|
||||||
assert await hass.auth.async_validate_access_token(tokens["access_token"]) is None
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is None
|
||||||
|
|
||||||
# Test that we no longer can create an access token
|
# Test that we no longer can create an access token
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
|
@ -454,7 +444,7 @@ async def test_ws_long_lived_access_token(
|
||||||
long_lived_access_token = result["result"]
|
long_lived_access_token = result["result"]
|
||||||
assert long_lived_access_token is not None
|
assert long_lived_access_token is not None
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(long_lived_access_token)
|
refresh_token = hass.auth.async_validate_access_token(long_lived_access_token)
|
||||||
assert refresh_token.client_id is None
|
assert refresh_token.client_id is None
|
||||||
assert refresh_token.client_name == "GPS Logger"
|
assert refresh_token.client_name == "GPS Logger"
|
||||||
assert refresh_token.client_icon is None
|
assert refresh_token.client_icon is None
|
||||||
|
@ -474,7 +464,7 @@ async def test_ws_refresh_tokens(
|
||||||
assert result["success"], result
|
assert result["success"], result
|
||||||
assert len(result["result"]) == 1
|
assert len(result["result"]) == 1
|
||||||
token = result["result"][0]
|
token = result["result"][0]
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
assert token["id"] == refresh_token.id
|
assert token["id"] == refresh_token.id
|
||||||
assert token["type"] == refresh_token.token_type
|
assert token["type"] == refresh_token.token_type
|
||||||
assert token["client_id"] == refresh_token.client_id
|
assert token["client_id"] == refresh_token.client_id
|
||||||
|
@ -514,7 +504,7 @@ async def test_ws_delete_refresh_token(
|
||||||
|
|
||||||
result = await ws_client.receive_json()
|
result = await ws_client.receive_json()
|
||||||
assert result["success"], result
|
assert result["success"], result
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(refresh_token.id)
|
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
|
||||||
assert refresh_token is None
|
assert refresh_token is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -573,7 +563,7 @@ async def test_ws_delete_all_refresh_tokens_error(
|
||||||
) in caplog.record_tuples
|
) in caplog.record_tuples
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
|
refresh_token = hass.auth.async_get_refresh_token(token["id"])
|
||||||
assert refresh_token is None
|
assert refresh_token is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -614,7 +604,7 @@ async def test_ws_delete_all_refresh_tokens(
|
||||||
result = await ws_client.receive_json()
|
result = await ws_client.receive_json()
|
||||||
assert result, result["success"]
|
assert result, result["success"]
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
|
refresh_token = hass.auth.async_get_refresh_token(token["id"])
|
||||||
assert refresh_token is None
|
assert refresh_token is None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -136,7 +136,7 @@ async def test_delete_unable_self_account(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we cannot delete our own account."""
|
"""Test we cannot delete our own account."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": refresh_token.user.id}
|
{"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": refresh_token.user.id}
|
||||||
|
|
|
@ -211,7 +211,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
||||||
token = hass_access_token
|
token = hass_access_token
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||||
assert req.status == HTTPStatus.OK
|
assert req.status == HTTPStatus.OK
|
||||||
|
@ -231,7 +231,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
||||||
req = await client.get("/", headers={"Authorization": f"BEARER {token}"})
|
req = await client.get("/", headers={"Authorization": f"BEARER {token}"})
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
refresh_token.user.is_active = False
|
refresh_token.user.is_active = False
|
||||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
@ -297,7 +297,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
signed_path = async_sign_path(
|
signed_path = async_sign_path(
|
||||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||||
|
@ -325,7 +325,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
# refresh token gone should also invalidate signature
|
# refresh token gone should also invalidate signature
|
||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
req = await client.get(signed_path)
|
req = await client.get(signed_path)
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
|
@ -342,7 +342,7 @@ async def test_auth_access_signed_path_with_query_param(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
signed_path = async_sign_path(
|
signed_path = async_sign_path(
|
||||||
hass, "/?test=test", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
hass, "/?test=test", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||||
|
@ -372,7 +372,7 @@ async def test_auth_access_signed_path_with_query_param_order(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
signed_path = async_sign_path(
|
signed_path = async_sign_path(
|
||||||
hass,
|
hass,
|
||||||
|
@ -413,7 +413,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
signed_path = async_sign_path(
|
signed_path = async_sign_path(
|
||||||
hass,
|
hass,
|
||||||
|
@ -452,7 +452,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
signed_path = async_sign_path(
|
signed_path = async_sign_path(
|
||||||
hass, base_url, timedelta(seconds=5), refresh_token_id=refresh_token.id
|
hass, base_url, timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||||
|
@ -491,9 +491,7 @@ async def test_auth_access_signed_path_via_websocket(
|
||||||
assert msg["id"] == 5
|
assert msg["id"] == 5
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(
|
refresh_token = hass.auth.async_validate_access_token(hass_read_only_access_token)
|
||||||
hass_read_only_access_token
|
|
||||||
)
|
|
||||||
signature = yarl.URL(msg["result"]["path"]).query["authSig"]
|
signature = yarl.URL(msg["result"]["path"]).query["authSig"]
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
signature,
|
signature,
|
||||||
|
@ -523,7 +521,7 @@ async def test_auth_access_signed_path_with_http(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
req = await client.get(
|
req = await client.get(
|
||||||
"/hello", headers={"Authorization": f"Bearer {hass_access_token}"}
|
"/hello", headers={"Authorization": f"Bearer {hass_access_token}"}
|
||||||
|
@ -567,7 +565,7 @@ async def test_local_only_user_rejected(
|
||||||
await async_setup_auth(hass, app)
|
await async_setup_auth(hass, app)
|
||||||
set_mock_ip = mock_real_ip(app)
|
set_mock_ip = mock_real_ip(app)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||||
assert req.status == HTTPStatus.OK
|
assert req.status == HTTPStatus.OK
|
||||||
|
|
|
@ -232,9 +232,7 @@ async def test_onboarding_user(
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
|
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate created areas
|
# Validate created areas
|
||||||
assert len(area_registry.areas) == 3
|
assert len(area_registry.areas) == 3
|
||||||
|
@ -347,9 +345,7 @@ async def test_onboarding_integration(
|
||||||
assert const.STEP_INTEGRATION in hass_storage[const.DOMAIN]["data"]["done"]
|
assert const.STEP_INTEGRATION in hass_storage[const.DOMAIN]["data"]["done"]
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
|
|
||||||
assert (
|
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Onboarding refresh token and new refresh token
|
# Onboarding refresh token and new refresh token
|
||||||
user = await hass.auth.async_get_user(hass_admin_user.id)
|
user = await hass.auth.async_get_user(hass_admin_user.id)
|
||||||
|
@ -368,7 +364,7 @@ async def test_onboarding_integration_missing_credential(
|
||||||
assert await async_setup_component(hass, "onboarding", {})
|
assert await async_setup_component(hass, "onboarding", {})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
refresh_token.credential = None
|
refresh_token.credential = None
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
|
@ -134,7 +134,7 @@ async def test_auth_active_user_inactive(
|
||||||
hass_access_token: str,
|
hass_access_token: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test authenticating with a token."""
|
"""Test authenticating with a token."""
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
refresh_token.user.is_active = False
|
refresh_token.user.is_active = False
|
||||||
assert await async_setup_component(hass, "websocket_api", {})
|
assert await async_setup_component(hass, "websocket_api", {})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
@ -216,8 +216,8 @@ async def test_auth_close_after_revoke(
|
||||||
"""Test that a websocket is closed after the refresh token is revoked."""
|
"""Test that a websocket is closed after the refresh token is revoked."""
|
||||||
assert not websocket_client.closed
|
assert not websocket_client.closed
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
|
|
||||||
msg = await websocket_client.receive()
|
msg = await websocket_client.receive()
|
||||||
assert msg.type == aiohttp.WSMsgType.CLOSE
|
assert msg.type == aiohttp.WSMsgType.CLOSE
|
||||||
|
|
|
@ -775,7 +775,7 @@ async def test_call_service_context_with_user(
|
||||||
msg = await ws.receive_json()
|
msg = await ws.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
call = calls[0]
|
call = calls[0]
|
||||||
|
|
Loading…
Add table
Reference in a new issue