Convert getting and removing access tokens to normal functions (#108670)

This commit is contained in:
J. Nick Koston 2024-01-22 20:51:33 -10:00 committed by GitHub
parent 904032e944
commit 2eea658fd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 98 additions and 124 deletions

View file

@ -458,23 +458,22 @@ class AuthManager:
credential, credential,
) )
async def async_get_refresh_token( @callback
self, token_id: str def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
) -> models.RefreshToken | None:
"""Get refresh token by id.""" """Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id) return self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token( @callback
def async_get_refresh_token_by_token(
self, token: str self, token: str
) -> models.RefreshToken | None: ) -> models.RefreshToken | None:
"""Get refresh token by token.""" """Get refresh token by token."""
return await self._store.async_get_refresh_token_by_token(token) return self._store.async_get_refresh_token_by_token(token)
async def async_remove_refresh_token( @callback
self, refresh_token: models.RefreshToken def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
) -> None:
"""Delete a refresh token.""" """Delete a refresh token."""
await self._store.async_remove_refresh_token(refresh_token) self._store.async_remove_refresh_token(refresh_token)
callbacks = self._revoke_callbacks.pop(refresh_token.id, ()) callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
for revoke_callback in callbacks: for revoke_callback in callbacks:
@ -554,16 +553,15 @@ class AuthManager:
if provider := self._async_resolve_provider(refresh_token): if provider := self._async_resolve_provider(refresh_token):
provider.async_validate_refresh_token(refresh_token, remote_ip) provider.async_validate_refresh_token(refresh_token, remote_ip)
async def async_validate_access_token( @callback
self, token: str def async_validate_access_token(self, token: str) -> models.RefreshToken | None:
) -> models.RefreshToken | None:
"""Return refresh token if an access token is valid.""" """Return refresh token if an access token is valid."""
try: try:
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token) unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return None return None
refresh_token = await self.async_get_refresh_token( refresh_token = self.async_get_refresh_token(
cast(str, unverif_claims.get("iss")) cast(str, unverif_claims.get("iss"))
) )

View file

@ -207,18 +207,16 @@ class AuthStore:
self._async_schedule_save() self._async_schedule_save()
return refresh_token return refresh_token
async def async_remove_refresh_token( @callback
self, refresh_token: models.RefreshToken def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
) -> None:
"""Remove a refresh token.""" """Remove a refresh token."""
for user in self._users.values(): for user in self._users.values():
if user.refresh_tokens.pop(refresh_token.id, None): if user.refresh_tokens.pop(refresh_token.id, None):
self._async_schedule_save() self._async_schedule_save()
break break
async def async_get_refresh_token( @callback
self, token_id: str def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
) -> models.RefreshToken | None:
"""Get refresh token by id.""" """Get refresh token by id."""
for user in self._users.values(): for user in self._users.values():
refresh_token = user.refresh_tokens.get(token_id) refresh_token = user.refresh_tokens.get(token_id)
@ -227,7 +225,8 @@ class AuthStore:
return None return None
async def async_get_refresh_token_by_token( @callback
def async_get_refresh_token_by_token(
self, token: str self, token: str
) -> models.RefreshToken | None: ) -> models.RefreshToken | None:
"""Get refresh token by token.""" """Get refresh token by token."""

View file

@ -124,7 +124,6 @@ as part of a config flow.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from http import HTTPStatus from http import HTTPStatus
@ -220,12 +219,12 @@ class RevokeTokenView(HomeAssistantView):
if (token := data.get("token")) is None: if (token := data.get("token")) is None:
return web.Response(status=HTTPStatus.OK) return web.Response(status=HTTPStatus.OK)
refresh_token = await hass.auth.async_get_refresh_token_by_token(token) refresh_token = hass.auth.async_get_refresh_token_by_token(token)
if refresh_token is None: if refresh_token is None:
return web.Response(status=HTTPStatus.OK) return web.Response(status=HTTPStatus.OK)
await hass.auth.async_remove_refresh_token(refresh_token) hass.auth.async_remove_refresh_token(refresh_token)
return web.Response(status=HTTPStatus.OK) return web.Response(status=HTTPStatus.OK)
@ -355,7 +354,7 @@ class TokenView(HomeAssistantView):
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST {"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
) )
refresh_token = await hass.auth.async_get_refresh_token_by_token(token) refresh_token = hass.auth.async_get_refresh_token_by_token(token)
if refresh_token is None: if refresh_token is None:
return self.json( return self.json(
@ -597,7 +596,7 @@ async def websocket_delete_refresh_token(
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token") connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
return return
await hass.auth.async_remove_refresh_token(refresh_token) hass.auth.async_remove_refresh_token(refresh_token)
connection.send_result(msg["id"], {}) connection.send_result(msg["id"], {})
@ -613,28 +612,23 @@ async def websocket_delete_all_refresh_tokens(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Handle delete all refresh tokens request.""" """Handle delete all refresh tokens request."""
tasks = []
current_refresh_token: RefreshToken current_refresh_token: RefreshToken
for token in connection.user.refresh_tokens.values(): remove_failed = False
for token in list(connection.user.refresh_tokens.values()):
if token.id == connection.refresh_token_id: if token.id == connection.refresh_token_id:
# Skip the current refresh token as it has revoke_callback, # Skip the current refresh token as it has revoke_callback,
# which cancels/closes the connection. # which cancels/closes the connection.
# It will be removed after sending the result. # It will be removed after sending the result.
current_refresh_token = token current_refresh_token = token
continue continue
tasks.append( try:
hass.async_create_task(hass.auth.async_remove_refresh_token(token)) hass.auth.async_remove_refresh_token(token)
) except Exception as err: # pylint: disable=broad-except
getLogger(__name__).exception(
remove_failed = False "During refresh token removal, the following error occurred: %s",
if tasks: err,
for result in await asyncio.gather(*tasks, return_exceptions=True): )
if isinstance(result, Exception): remove_failed = True
getLogger(__name__).exception(
"During refresh token removal, the following error occurred: %s",
result,
)
remove_failed = True
if remove_failed: if remove_failed:
connection.send_error( connection.send_error(
@ -643,7 +637,8 @@ async def websocket_delete_all_refresh_tokens(
else: else:
connection.send_result(msg["id"], {}) connection.send_result(msg["id"], {})
hass.async_create_task(hass.auth.async_remove_refresh_token(current_refresh_token)) # This will close the connection so we need to send the result first.
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
@websocket_api.websocket_command( @websocket_api.websocket_command(

View file

@ -151,7 +151,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
if auth_type != "Bearer": if auth_type != "Bearer":
return False return False
refresh_token = await hass.auth.async_validate_access_token(auth_val) refresh_token = hass.auth.async_validate_access_token(auth_val)
if refresh_token is None: if refresh_token is None:
return False return False
@ -189,7 +189,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
if claims["params"] != params: if claims["params"] != params:
return False return False
refresh_token = await hass.auth.async_get_refresh_token(claims["iss"]) refresh_token = hass.auth.async_get_refresh_token(claims["iss"])
if refresh_token is None: if refresh_token is None:
return False return False

View file

@ -259,7 +259,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST "invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
) )
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id) refresh_token = hass.auth.async_get_refresh_token(refresh_token_id)
if refresh_token is None or refresh_token.credential is None: if refresh_token is None or refresh_token.credential is None:
return self.json_message( return self.json_message(
"Credentials for user not available", HTTPStatus.FORBIDDEN "Credentials for user not available", HTTPStatus.FORBIDDEN

View file

@ -80,9 +80,7 @@ class AuthPhase:
raise Disconnect from err raise Disconnect from err
if (access_token := valid_msg.get("access_token")) and ( if (access_token := valid_msg.get("access_token")) and (
refresh_token := await self._hass.auth.async_validate_access_token( refresh_token := self._hass.auth.async_validate_access_token(access_token)
access_token
)
): ):
conn = ActiveConnection( conn = ActiveConnection(
self._logger, self._logger,

View file

@ -371,7 +371,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
assert refresh_token.client_id == CLIENT_ID assert refresh_token.client_id == CLIENT_ID
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
assert await manager.async_validate_access_token(access_token) is refresh_token assert manager.async_validate_access_token(access_token) is refresh_token
# We patch time directly here because we want the access token to be created with # We patch time directly here because we want the access token to be created with
# an expired time, but we do not want to freeze time so that jwt will compare it # an expired time, but we do not want to freeze time so that jwt will compare it
@ -385,7 +385,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
): ):
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
assert await manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
async def test_generating_system_user(hass: HomeAssistant) -> None: async def test_generating_system_user(hass: HomeAssistant) -> None:
@ -572,10 +572,10 @@ async def test_remove_refresh_token(mock_hass) -> None:
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
await manager.async_remove_refresh_token(refresh_token) manager.async_remove_refresh_token(refresh_token)
assert await manager.async_get_refresh_token(refresh_token.id) is None assert manager.async_get_refresh_token(refresh_token.id) is None
assert await manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
async def test_register_revoke_token_callback(mock_hass) -> None: async def test_register_revoke_token_callback(mock_hass) -> None:
@ -591,7 +591,7 @@ async def test_register_revoke_token_callback(mock_hass) -> None:
called = True called = True
manager.async_register_revoke_token_callback(refresh_token.id, cb) manager.async_register_revoke_token_callback(refresh_token.id, cb)
await manager.async_remove_refresh_token(refresh_token) manager.async_remove_refresh_token(refresh_token)
assert called assert called
@ -610,7 +610,7 @@ async def test_unregister_revoke_token_callback(mock_hass) -> None:
unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb) unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb)
unregister() unregister()
await manager.async_remove_refresh_token(refresh_token) manager.async_remove_refresh_token(refresh_token)
assert not called assert not called
@ -664,7 +664,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
jwt_key = refresh_token.jwt_key jwt_key = refresh_token.jwt_key
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -675,9 +675,9 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
access_token_expiration=timedelta(days=3000), access_token_expiration=timedelta(days=3000),
) )
await manager.async_remove_refresh_token(refresh_token) manager.async_remove_refresh_token(refresh_token)
assert refresh_token.id not in user.refresh_tokens assert refresh_token.id not in user.refresh_tokens
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt is None, "Previous issued access token has been invoked" assert rt is None, "Previous issued access token has been invoked"
refresh_token_2 = await manager.async_create_refresh_token( refresh_token_2 = await manager.async_create_refresh_token(
@ -694,7 +694,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
assert access_token != access_token_2 assert access_token != access_token_2
assert jwt_key != jwt_key_2 assert jwt_key != jwt_key_2
rt = await manager.async_validate_access_token(access_token_2) rt = manager.async_validate_access_token(access_token_2)
jwt_payload = jwt.decode(access_token_2, rt.jwt_key, algorithms=["HS256"]) jwt_payload = jwt.decode(access_token_2, rt.jwt_key, algorithms=["HS256"])
assert jwt_payload["iss"] == refresh_token_2.id assert jwt_payload["iss"] == refresh_token_2.id
assert ( assert (
@ -1144,7 +1144,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id
# Now we corrupt the signature # Now we corrupt the signature
@ -1154,7 +1154,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
assert access_token != invalid_token assert access_token != invalid_token
result = await manager.async_validate_access_token(invalid_token) result = manager.async_validate_access_token(invalid_token)
assert result is None assert result is None
@ -1171,7 +1171,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id
# Now we make the signature all nulls # Now we make the signature all nulls
@ -1181,7 +1181,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
assert access_token != invalid_token assert access_token != invalid_token
result = await manager.async_validate_access_token(invalid_token) result = manager.async_validate_access_token(invalid_token)
assert result is None assert result is None
@ -1198,7 +1198,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id
# Now we make the signature all nulls # Now we make the signature all nulls
@ -1207,7 +1207,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
assert access_token != invalid_token assert access_token != invalid_token
result = await manager.async_validate_access_token(invalid_token) result = manager.async_validate_access_token(invalid_token)
assert result is None assert result is None
@ -1225,17 +1225,17 @@ async def test_access_token_with_empty_key(mock_hass) -> None:
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
await manager.async_remove_refresh_token(refresh_token) manager.async_remove_refresh_token(refresh_token)
# Now remove the token from the keyring # Now remove the token from the keyring
# so we will get an empty key # so we will get an empty key
assert await manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None: async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
"""Test rejecting access tokens with impossible sizes.""" """Test rejecting access tokens with impossible sizes."""
manager = await auth.auth_manager_from_config(mock_hass, [], []) manager = await auth.auth_manager_from_config(mock_hass, [], [])
assert await manager.async_validate_access_token("a" * 10000) is None assert manager.async_validate_access_token("a" * 10000) is None
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None: async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
@ -1245,7 +1245,7 @@ async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"} b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
) )
manager = await auth.auth_manager_from_config(mock_hass, [], []) manager = await auth.auth_manager_from_config(mock_hass, [], [])
assert await manager.async_validate_access_token(token_with_invalid_json) is None assert manager.async_validate_access_token(token_with_invalid_json) is None
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None: async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
@ -1255,7 +1255,7 @@ async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"} b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
) )
manager = await auth.auth_manager_from_config(mock_hass, [], []) manager = await auth.auth_manager_from_config(mock_hass, [], [])
assert await manager.async_validate_access_token(token_not_a_dict_json) is None assert manager.async_validate_access_token(token_not_a_dict_json) is None
async def test_access_token_that_expires_soon(mock_hass) -> None: async def test_access_token_that_expires_soon(mock_hass) -> None:
@ -1272,11 +1272,11 @@ async def test_access_token_that_expires_soon(mock_hass) -> None:
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id
with freeze_time(now + timedelta(minutes=1)): with freeze_time(now + timedelta(minutes=1)):
assert await manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
async def test_access_token_from_the_future(mock_hass) -> None: async def test_access_token_from_the_future(mock_hass) -> None:
@ -1296,8 +1296,8 @@ async def test_access_token_from_the_future(mock_hass) -> None:
) )
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
assert await manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
with freeze_time(now + timedelta(days=365)): with freeze_time(now + timedelta(days=365)):
rt = await manager.async_validate_access_token(access_token) rt = manager.async_validate_access_token(access_token)
assert rt.id == refresh_token.id assert rt.id == refresh_token.id

View file

@ -588,7 +588,7 @@ async def test_api_fire_event_context(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert len(test_value) == 1 assert len(test_value) == 1
assert test_value[0].context.user_id == refresh_token.user.id assert test_value[0].context.user_id == refresh_token.user.id
@ -606,7 +606,7 @@ async def test_api_call_service_context(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context.user_id == refresh_token.user.id assert calls[0].context.user_id == refresh_token.user.id
@ -622,7 +622,7 @@ async def test_api_set_state_context(
headers={"authorization": f"Bearer {hass_access_token}"}, headers={"authorization": f"Bearer {hass_access_token}"},
) )
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
state = hass.states.get("light.kitchen") state = hass.states.get("light.kitchen")
assert state.context.user_id == refresh_token.user.id assert state.context.user_id == refresh_token.user.id

View file

@ -88,9 +88,7 @@ async def test_login_new_user_and_trying_refresh_token(
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
assert tokens["ha_auth_provider"] == "insecure_example" assert tokens["ha_auth_provider"] == "insecure_example"
# Use refresh token to get more tokens. # Use refresh token to get more tokens.
@ -106,9 +104,7 @@ async def test_login_new_user_and_trying_refresh_token(
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
tokens = await resp.json() tokens = await resp.json()
assert "refresh_token" not in tokens assert "refresh_token" not in tokens
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
# Test using access token to hit API. # Test using access token to hit API.
resp = await client.get("/api/") resp = await client.get("/api/")
@ -205,7 +201,7 @@ async def test_ws_current_user(
"""Test the current user command with Home Assistant creds.""" """Test the current user command with Home Assistant creds."""
assert await async_setup_component(hass, "auth", {}) assert await async_setup_component(hass, "auth", {})
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
user = refresh_token.user user = refresh_token.user
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
@ -275,9 +271,7 @@ async def test_refresh_token_system_generated(
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
async def test_refresh_token_different_client_id( async def test_refresh_token_different_client_id(
@ -323,9 +317,7 @@ async def test_refresh_token_different_client_id(
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
async def test_refresh_token_checks_local_only_user( async def test_refresh_token_checks_local_only_user(
@ -406,16 +398,14 @@ async def test_revoking_refresh_token(
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
# Revoke refresh token # Revoke refresh token
resp = await client.post(url, data={**base_data, "token": refresh_token.token}) resp = await client.post(url, data={**base_data, "token": refresh_token.token})
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
# Old access token should be no longer valid # Old access token should be no longer valid
assert await hass.auth.async_validate_access_token(tokens["access_token"]) is None assert hass.auth.async_validate_access_token(tokens["access_token"]) is None
# Test that we no longer can create an access token # Test that we no longer can create an access token
resp = await client.post( resp = await client.post(
@ -454,7 +444,7 @@ async def test_ws_long_lived_access_token(
long_lived_access_token = result["result"] long_lived_access_token = result["result"]
assert long_lived_access_token is not None assert long_lived_access_token is not None
refresh_token = await hass.auth.async_validate_access_token(long_lived_access_token) refresh_token = hass.auth.async_validate_access_token(long_lived_access_token)
assert refresh_token.client_id is None assert refresh_token.client_id is None
assert refresh_token.client_name == "GPS Logger" assert refresh_token.client_name == "GPS Logger"
assert refresh_token.client_icon is None assert refresh_token.client_icon is None
@ -474,7 +464,7 @@ async def test_ws_refresh_tokens(
assert result["success"], result assert result["success"], result
assert len(result["result"]) == 1 assert len(result["result"]) == 1
token = result["result"][0] token = result["result"][0]
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert token["id"] == refresh_token.id assert token["id"] == refresh_token.id
assert token["type"] == refresh_token.token_type assert token["type"] == refresh_token.token_type
assert token["client_id"] == refresh_token.client_id assert token["client_id"] == refresh_token.client_id
@ -514,7 +504,7 @@ async def test_ws_delete_refresh_token(
result = await ws_client.receive_json() result = await ws_client.receive_json()
assert result["success"], result assert result["success"], result
refresh_token = await hass.auth.async_get_refresh_token(refresh_token.id) refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token is None assert refresh_token is None
@ -573,7 +563,7 @@ async def test_ws_delete_all_refresh_tokens_error(
) in caplog.record_tuples ) in caplog.record_tuples
for token in tokens: for token in tokens:
refresh_token = await hass.auth.async_get_refresh_token(token["id"]) refresh_token = hass.auth.async_get_refresh_token(token["id"])
assert refresh_token is None assert refresh_token is None
@ -614,7 +604,7 @@ async def test_ws_delete_all_refresh_tokens(
result = await ws_client.receive_json() result = await ws_client.receive_json()
assert result, result["success"] assert result, result["success"]
for token in tokens: for token in tokens:
refresh_token = await hass.auth.async_get_refresh_token(token["id"]) refresh_token = hass.auth.async_get_refresh_token(token["id"])
assert refresh_token is None assert refresh_token is None

View file

@ -136,7 +136,7 @@ async def test_delete_unable_self_account(
) -> None: ) -> None:
"""Test we cannot delete our own account.""" """Test we cannot delete our own account."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
await client.send_json( await client.send_json(
{"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": refresh_token.user.id} {"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": refresh_token.user.id}

View file

@ -211,7 +211,7 @@ async def test_auth_active_access_with_access_token_in_header(
token = hass_access_token token = hass_access_token
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
req = await client.get("/", headers={"Authorization": f"Bearer {token}"}) req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
@ -231,7 +231,7 @@ async def test_auth_active_access_with_access_token_in_header(
req = await client.get("/", headers={"Authorization": f"BEARER {token}"}) req = await client.get("/", headers={"Authorization": f"BEARER {token}"})
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
refresh_token.user.is_active = False refresh_token.user.is_active = False
req = await client.get("/", headers={"Authorization": f"Bearer {token}"}) req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
@ -297,7 +297,7 @@ async def test_auth_access_signed_path_with_refresh_token(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path( signed_path = async_sign_path(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
@ -325,7 +325,7 @@ async def test_auth_access_signed_path_with_refresh_token(
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
# refresh token gone should also invalidate signature # refresh token gone should also invalidate signature
await hass.auth.async_remove_refresh_token(refresh_token) hass.auth.async_remove_refresh_token(refresh_token)
req = await client.get(signed_path) req = await client.get(signed_path)
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
@ -342,7 +342,7 @@ async def test_auth_access_signed_path_with_query_param(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path( signed_path = async_sign_path(
hass, "/?test=test", timedelta(seconds=5), refresh_token_id=refresh_token.id hass, "/?test=test", timedelta(seconds=5), refresh_token_id=refresh_token.id
@ -372,7 +372,7 @@ async def test_auth_access_signed_path_with_query_param_order(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path( signed_path = async_sign_path(
hass, hass,
@ -413,7 +413,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path( signed_path = async_sign_path(
hass, hass,
@ -452,7 +452,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path( signed_path = async_sign_path(
hass, base_url, timedelta(seconds=5), refresh_token_id=refresh_token.id hass, base_url, timedelta(seconds=5), refresh_token_id=refresh_token.id
@ -491,9 +491,7 @@ async def test_auth_access_signed_path_via_websocket(
assert msg["id"] == 5 assert msg["id"] == 5
assert msg["success"] assert msg["success"]
refresh_token = await hass.auth.async_validate_access_token( refresh_token = hass.auth.async_validate_access_token(hass_read_only_access_token)
hass_read_only_access_token
)
signature = yarl.URL(msg["result"]["path"]).query["authSig"] signature = yarl.URL(msg["result"]["path"]).query["authSig"]
claims = jwt.decode( claims = jwt.decode(
signature, signature,
@ -523,7 +521,7 @@ async def test_auth_access_signed_path_with_http(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
req = await client.get( req = await client.get(
"/hello", headers={"Authorization": f"Bearer {hass_access_token}"} "/hello", headers={"Authorization": f"Bearer {hass_access_token}"}
@ -567,7 +565,7 @@ async def test_local_only_user_rejected(
await async_setup_auth(hass, app) await async_setup_auth(hass, app)
set_mock_ip = mock_real_ip(app) set_mock_ip = mock_real_ip(app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
req = await client.get("/", headers={"Authorization": f"Bearer {token}"}) req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK

View file

@ -232,9 +232,7 @@ async def test_onboarding_user(
assert resp.status == 200 assert resp.status == 200
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
# Validate created areas # Validate created areas
assert len(area_registry.areas) == 3 assert len(area_registry.areas) == 3
@ -347,9 +345,7 @@ async def test_onboarding_integration(
assert const.STEP_INTEGRATION in hass_storage[const.DOMAIN]["data"]["done"] assert const.STEP_INTEGRATION in hass_storage[const.DOMAIN]["data"]["done"]
tokens = await resp.json() tokens = await resp.json()
assert ( assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
)
# Onboarding refresh token and new refresh token # Onboarding refresh token and new refresh token
user = await hass.auth.async_get_user(hass_admin_user.id) user = await hass.auth.async_get_user(hass_admin_user.id)
@ -368,7 +364,7 @@ async def test_onboarding_integration_missing_credential(
assert await async_setup_component(hass, "onboarding", {}) assert await async_setup_component(hass, "onboarding", {})
await hass.async_block_till_done() await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
refresh_token.credential = None refresh_token.credential = None
client = await hass_client() client = await hass_client()

View file

@ -134,7 +134,7 @@ async def test_auth_active_user_inactive(
hass_access_token: str, hass_access_token: str,
) -> None: ) -> None:
"""Test authenticating with a token.""" """Test authenticating with a token."""
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
refresh_token.user.is_active = False refresh_token.user.is_active = False
assert await async_setup_component(hass, "websocket_api", {}) assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done() await hass.async_block_till_done()
@ -216,8 +216,8 @@ async def test_auth_close_after_revoke(
"""Test that a websocket is closed after the refresh token is revoked.""" """Test that a websocket is closed after the refresh token is revoked."""
assert not websocket_client.closed assert not websocket_client.closed
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
await hass.auth.async_remove_refresh_token(refresh_token) hass.auth.async_remove_refresh_token(refresh_token)
msg = await websocket_client.receive() msg = await websocket_client.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE assert msg.type == aiohttp.WSMsgType.CLOSE

View file

@ -775,7 +775,7 @@ async def test_call_service_context_with_user(
msg = await ws.receive_json() msg = await ws.receive_json()
assert msg["success"] assert msg["success"]
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert len(calls) == 1 assert len(calls) == 1
call = calls[0] call = calls[0]