Cache decode of JWT tokens (#90013)
This commit is contained in:
parent
8a591fa16e
commit
ca576d45ac
5 changed files with 308 additions and 6 deletions
|
@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
|||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import auth_store, models
|
||||
from . import auth_store, jwt_wrapper, models
|
||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
|
||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
|
@ -555,9 +555,7 @@ class AuthManager:
|
|||
) -> models.RefreshToken | None:
|
||||
"""Return refresh token if an access token is valid."""
|
||||
try:
|
||||
unverif_claims = jwt.decode(
|
||||
token, algorithms=["HS256"], options={"verify_signature": False}
|
||||
)
|
||||
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
|
@ -573,7 +571,9 @@ class AuthManager:
|
|||
issuer = refresh_token.id
|
||||
|
||||
try:
|
||||
jwt.decode(token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"])
|
||||
jwt_wrapper.verify_and_decode(
|
||||
token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"]
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
|
|
116
homeassistant/auth/jwt_wrapper.py
Normal file
116
homeassistant/auth/jwt_wrapper.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
"""Provide a wrapper around JWT that caches decoding tokens.
|
||||
|
||||
Since we decode the same tokens over and over again
|
||||
we can cache the result of the decode of valid tokens
|
||||
to speed up the process.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache, partial
|
||||
from typing import Any
|
||||
|
||||
from jwt import DecodeError, PyJWS, PyJWT
|
||||
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
JWT_TOKEN_CACHE_SIZE = 16
|
||||
MAX_TOKEN_SIZE = 8192
|
||||
|
||||
_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss")
|
||||
|
||||
_VERIFY_OPTIONS: dict[str, Any] = {f"verify_{key}": True for key in _VERIFY_KEYS} | {
|
||||
"require": []
|
||||
}
|
||||
_NO_VERIFY_OPTIONS = {f"verify_{key}": False for key in _VERIFY_KEYS}
|
||||
|
||||
|
||||
class _PyJWSWithLoadCache(PyJWS):
|
||||
"""PyJWS with a dedicated load implementation."""
|
||||
|
||||
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
|
||||
# We only ever have a global instance of this class
|
||||
# so we do not have to worry about the LRU growing
|
||||
# each time we create a new instance.
|
||||
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
|
||||
"""Load a JWS."""
|
||||
return super()._load(jwt)
|
||||
|
||||
|
||||
_jws = _PyJWSWithLoadCache()
|
||||
|
||||
|
||||
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
|
||||
def _decode_payload(json_payload: str) -> dict[str, Any]:
|
||||
"""Decode the payload from a JWS dictionary."""
|
||||
try:
|
||||
payload = json_loads(json_payload)
|
||||
except ValueError as err:
|
||||
raise DecodeError(f"Invalid payload string: {err}") from err
|
||||
if not isinstance(payload, dict):
|
||||
raise DecodeError("Invalid payload string: must be a json object")
|
||||
return payload
|
||||
|
||||
|
||||
class _PyJWTWithVerify(PyJWT):
|
||||
"""PyJWT with a fast decode implementation."""
|
||||
|
||||
def decode_payload(
|
||||
self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""Decode a JWT's payload."""
|
||||
if len(jwt) > MAX_TOKEN_SIZE:
|
||||
# Avoid caching impossible tokens
|
||||
raise DecodeError("Token too large")
|
||||
return _decode_payload(
|
||||
_jws.decode_complete(
|
||||
jwt=jwt,
|
||||
key=key,
|
||||
algorithms=algorithms,
|
||||
options=options,
|
||||
)["payload"]
|
||||
)
|
||||
|
||||
def verify_and_decode(
|
||||
self,
|
||||
jwt: str,
|
||||
key: str,
|
||||
algorithms: list[str],
|
||||
issuer: str | None = None,
|
||||
leeway: int | float | timedelta = 0,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Verify a JWT's signature and claims."""
|
||||
merged_options = {**_VERIFY_OPTIONS, **(options or {})}
|
||||
payload = self.decode_payload(
|
||||
jwt=jwt,
|
||||
key=key,
|
||||
options=merged_options,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
# These should never be missing since we verify them
|
||||
# but this is an additional safeguard to make sure
|
||||
# nothing slips through.
|
||||
assert "exp" in payload, "exp claim is required"
|
||||
assert "iat" in payload, "iat claim is required"
|
||||
self._validate_claims( # type: ignore[no-untyped-call]
|
||||
payload=payload,
|
||||
options=merged_options,
|
||||
issuer=issuer,
|
||||
leeway=leeway,
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
_jwt = _PyJWTWithVerify() # type: ignore[no-untyped-call]
|
||||
verify_and_decode = _jwt.verify_and_decode
|
||||
unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)(
|
||||
partial(
|
||||
_jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS
|
||||
)
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"unverified_hs256_token_decode",
|
||||
"verify_and_decode",
|
||||
]
|
|
@ -13,6 +13,7 @@ from aiohttp.web import Application, Request, StreamResponse, middleware
|
|||
import jwt
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth import jwt_wrapper
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
|
@ -175,7 +176,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
return False
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
claims = jwt_wrapper.verify_and_decode(
|
||||
signature, secret, algorithms=["HS256"], options={"verify_iss": False}
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
|
|
|
@ -3,6 +3,7 @@ from datetime import timedelta
|
|||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import jwt
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
@ -1127,3 +1128,175 @@ async def test_event_user_updated_fires(hass: HomeAssistant) -> None:
|
|||
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with an invalid signature."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Good Client",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(days=3000),
|
||||
)
|
||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we corrupt the signature
|
||||
header, payload, signature = access_token.split(".")
|
||||
invalid_signature = "a" * len(signature)
|
||||
invalid_token = f"{header}.{payload}.{invalid_signature}"
|
||||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_access_token_with_null_signature(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with a null signature."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Good Client",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(days=3000),
|
||||
)
|
||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we make the signature all nulls
|
||||
header, payload, signature = access_token.split(".")
|
||||
invalid_signature = "\0" * len(signature)
|
||||
invalid_token = f"{header}.{payload}.{invalid_signature}"
|
||||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_access_token_with_empty_signature(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with an empty signature."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Good Client",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(days=3000),
|
||||
)
|
||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we make the signature all nulls
|
||||
header, payload, _ = access_token.split(".")
|
||||
invalid_token = f"{header}.{payload}."
|
||||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_access_token_with_empty_key(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with an empty key."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Good Client",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(days=3000),
|
||||
)
|
||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
# Now remove the token from the keyring
|
||||
# so we will get an empty key
|
||||
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with impossible sizes."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token("a" * 10000) is None
|
||||
|
||||
|
||||
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with invalid json payload."""
|
||||
jws = jwt.PyJWS()
|
||||
token_with_invalid_json = jws.encode(
|
||||
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||
)
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token(token_with_invalid_json) is None
|
||||
|
||||
|
||||
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with not a dict json payload."""
|
||||
jws = jwt.PyJWS()
|
||||
token_not_a_dict_json = jws.encode(
|
||||
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||
)
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token(token_not_a_dict_json) is None
|
||||
|
||||
|
||||
async def test_access_token_that_expires_soon(mock_hass) -> None:
|
||||
"""Test access token from refresh token that expires very soon."""
|
||||
now = dt_util.utcnow()
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Token that expires very soon",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(seconds=1),
|
||||
)
|
||||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
with freeze_time(now + timedelta(minutes=1)):
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_access_token_from_the_future(mock_hass) -> None:
|
||||
"""Test we reject an access token from the future."""
|
||||
now = dt_util.utcnow()
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
with freeze_time(now + timedelta(days=365)):
|
||||
refresh_token = await manager.async_create_refresh_token(
|
||||
user,
|
||||
client_name="Token that expires very soon",
|
||||
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
access_token_expiration=timedelta(days=10),
|
||||
)
|
||||
assert (
|
||||
refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
)
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
|
||||
with freeze_time(now + timedelta(days=365)):
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
|
12
tests/auth/test_jwt_wrapper.py
Normal file
12
tests/auth/test_jwt_wrapper.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
"""Tests for the Home Assistant auth jwt_wrapper module."""
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth import jwt_wrapper
|
||||
|
||||
|
||||
async def test_reject_access_token_with_impossible_large_size() -> None:
|
||||
"""Test rejecting access tokens with impossible sizes."""
|
||||
with pytest.raises(jwt.DecodeError):
|
||||
jwt_wrapper.unverified_hs256_token_decode("a" * 10000)
|
Loading…
Add table
Reference in a new issue