Cache decode of JWT tokens (#90013)

This commit is contained in:
J. Nick Koston 2023-03-22 16:03:41 -10:00 committed by GitHub
parent 8a591fa16e
commit ca576d45ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 308 additions and 6 deletions

View file

@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util import dt as dt_util
from . import auth_store, models
from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
@ -555,9 +555,7 @@ class AuthManager:
) -> models.RefreshToken | None:
"""Return refresh token if an access token is valid."""
try:
unverif_claims = jwt.decode(
token, algorithms=["HS256"], options={"verify_signature": False}
)
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
except jwt.InvalidTokenError:
return None
@ -573,7 +571,9 @@ class AuthManager:
issuer = refresh_token.id
try:
jwt.decode(token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"])
jwt_wrapper.verify_and_decode(
token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"]
)
except jwt.InvalidTokenError:
return None

View file

@ -0,0 +1,116 @@
"""Provide a wrapper around JWT that caches decoding tokens.
Since we decode the same tokens over and over again
we can cache the result of the decode of valid tokens
to speed up the process.
"""
from __future__ import annotations

from datetime import timedelta
from functools import lru_cache, partial
from typing import Any

from jwt import DecodeError, MissingRequiredClaimError, PyJWS, PyJWT

from homeassistant.util.json import json_loads
# Maximum number of entries held by each LRU cache in this module.
JWT_TOKEN_CACHE_SIZE = 16
# Tokens longer than this many characters are rejected before decoding.
MAX_TOKEN_SIZE = 8192

# Suffixes of the ``verify_*`` option names that get switched on or off.
_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss")

# Every ``verify_*`` option enabled, plus an empty list of required claims.
_VERIFY_OPTIONS: dict[str, Any] = {
    **dict.fromkeys((f"verify_{name}" for name in _VERIFY_KEYS), True),
    "require": [],
}
# Every ``verify_*`` option disabled.
_NO_VERIFY_OPTIONS = dict.fromkeys((f"verify_{name}" for name in _VERIFY_KEYS), False)
class _PyJWSWithLoadCache(PyJWS):
    """PyJWS subclass whose ``_load`` results are memoized."""

    # The LRU cache keys on ``self`` as well as on the token. That is fine
    # here: only the single module-level instance created below ever exists,
    # so the cache does not grow with each newly created instance.
    @lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
        """Load a JWS, reusing the cached result for a token seen before."""
        loaded = super()._load(jwt)
        return loaded


# The one shared instance used by the rest of this module.
_jws = _PyJWSWithLoadCache()
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
def _decode_payload(json_payload: str) -> dict[str, Any]:
    """Parse a JWS payload into a dictionary of claims.

    Successful results are kept in an LRU cache keyed on the raw payload.

    Raises:
        DecodeError: If the payload is not valid JSON, or if it is valid
            JSON but not a JSON object.
    """
    try:
        claims = json_loads(json_payload)
    except ValueError as exc:
        raise DecodeError(f"Invalid payload string: {exc}") from exc

    if isinstance(claims, dict):
        return claims
    raise DecodeError("Invalid payload string: must be a json object")
class _PyJWTWithVerify(PyJWT):
    """PyJWT with a fast decode implementation.

    Signature handling is delegated to the module-level ``_jws`` instance
    and payload parsing to the cached ``_decode_payload`` helper.
    """

    def decode_payload(
        self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str]
    ) -> dict[str, Any]:
        """Decode a JWT's payload.

        Args:
            jwt: The encoded token.
            key: Key handed to the JWS layer.
            options: ``verify_*`` options handed to the JWS layer.
            algorithms: Signing algorithms that are accepted.

        Returns:
            The payload parsed into a dictionary.

        Raises:
            DecodeError: If the token is longer than ``MAX_TOKEN_SIZE`` or
                its payload cannot be parsed into a JSON object.
        """
        if len(jwt) > MAX_TOKEN_SIZE:
            # Avoid caching impossible tokens
            raise DecodeError("Token too large")
        return _decode_payload(
            _jws.decode_complete(
                jwt=jwt,
                key=key,
                algorithms=algorithms,
                options=options,
            )["payload"]
        )

    def verify_and_decode(
        self,
        jwt: str,
        key: str,
        algorithms: list[str],
        issuer: str | None = None,
        leeway: int | float | timedelta = 0,
        options: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Verify a JWT's signature and claims.

        Args:
            jwt: The encoded token.
            key: Key used to check the signature.
            algorithms: Signing algorithms that are accepted.
            issuer: Expected issuer, passed on to claim validation.
            leeway: Tolerance passed on to claim validation.
            options: Overrides applied on top of ``_VERIFY_OPTIONS``.

        Returns:
            The decoded payload.

        Raises:
            MissingRequiredClaimError: If the payload has no ``exp`` or no
                ``iat`` claim.
            DecodeError: See ``decode_payload``.
        """
        merged_options = {**_VERIFY_OPTIONS, **(options or {})}
        payload = self.decode_payload(
            jwt=jwt,
            key=key,
            options=merged_options,
            algorithms=algorithms,
        )
        # ``_VERIFY_OPTIONS`` carries an empty ``require`` list, so the
        # presence of these claims is enforced explicitly here. This is a
        # real exception rather than an ``assert`` for two reasons: assert
        # statements are removed when Python runs with -O, and callers
        # handle jwt.InvalidTokenError (the documented base class of
        # MissingRequiredClaimError), not AssertionError.
        for claim in ("exp", "iat"):
            if claim not in payload:
                raise MissingRequiredClaimError(claim)
        self._validate_claims(  # type: ignore[no-untyped-call]
            payload=payload,
            options=merged_options,
            issuer=issuer,
            leeway=leeway,
        )
        return payload
# Single shared instance; its bound methods back the public helpers below.
_jwt = _PyJWTWithVerify()  # type: ignore[no-untyped-call]

# Public helper that checks the signature and the claims of a token.
verify_and_decode = _jwt.verify_and_decode

# Public helper that decodes a token with every ``verify_*`` option turned
# off and an empty key. Results are kept in an LRU cache keyed on the
# arguments the caller passes (the token).
unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)(
    partial(
        _jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS
    )
)

# Names exported by this module.
__all__ = [
    "unverified_hs256_token_decode",
    "verify_and_decode",
]

View file

@ -13,6 +13,7 @@ from aiohttp.web import Application, Request, StreamResponse, middleware
import jwt
from yarl import URL
from homeassistant.auth import jwt_wrapper
from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.components import websocket_api
@ -175,7 +176,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
return False
try:
claims = jwt.decode(
claims = jwt_wrapper.verify_and_decode(
signature, secret, algorithms=["HS256"], options={"verify_iss": False}
)
except jwt.InvalidTokenError:

View file

@ -3,6 +3,7 @@ from datetime import timedelta
from typing import Any
from unittest.mock import Mock, patch
from freezegun import freeze_time
import jwt
import pytest
import voluptuous as vol
@ -1127,3 +1128,175 @@ async def test_event_user_updated_fires(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
assert len(events) == 1
async def test_access_token_with_invalid_signature(mock_hass) -> None:
    """Test that an access token with a tampered signature is rejected."""
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)
    long_lived = await auth_manager.async_create_refresh_token(
        owner,
        client_name="Good Client",
        token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
        access_token_expiration=timedelta(days=3000),
    )
    assert long_lived.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN

    # Sanity check: the untouched token resolves to its refresh token.
    access_token = auth_manager.async_create_access_token(long_lived)
    validated = await auth_manager.async_validate_access_token(access_token)
    assert validated.id == long_lived.id

    # Replace the signature with same-length filler, keeping the header
    # and the payload intact.
    header, payload, signature = access_token.split(".")
    invalid_signature = "a" * len(signature)
    invalid_token = f"{header}.{payload}.{invalid_signature}"
    assert access_token != invalid_token

    assert await auth_manager.async_validate_access_token(invalid_token) is None
async def test_access_token_with_null_signature(mock_hass) -> None:
    """Test that an access token whose signature is all NUL chars is rejected."""
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)
    long_lived = await auth_manager.async_create_refresh_token(
        owner,
        client_name="Good Client",
        token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
        access_token_expiration=timedelta(days=3000),
    )
    assert long_lived.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN

    # Sanity check: the untouched token resolves to its refresh token.
    access_token = auth_manager.async_create_access_token(long_lived)
    validated = await auth_manager.async_validate_access_token(access_token)
    assert validated.id == long_lived.id

    # Overwrite the signature with NUL characters of the same length.
    header, payload, signature = access_token.split(".")
    invalid_signature = "\0" * len(signature)
    invalid_token = f"{header}.{payload}.{invalid_signature}"
    assert access_token != invalid_token

    assert await auth_manager.async_validate_access_token(invalid_token) is None
async def test_access_token_with_empty_signature(mock_hass) -> None:
    """Test that an access token with its signature stripped is rejected."""
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)
    long_lived = await auth_manager.async_create_refresh_token(
        owner,
        client_name="Good Client",
        token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
        access_token_expiration=timedelta(days=3000),
    )
    assert long_lived.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN

    # Sanity check: the untouched token resolves to its refresh token.
    access_token = auth_manager.async_create_access_token(long_lived)
    validated = await auth_manager.async_validate_access_token(access_token)
    assert validated.id == long_lived.id

    # Drop the signature entirely, leaving only the trailing separator.
    header, payload, _ = access_token.split(".")
    invalid_token = f"{header}.{payload}."
    assert access_token != invalid_token

    assert await auth_manager.async_validate_access_token(invalid_token) is None
async def test_access_token_with_empty_key(mock_hass) -> None:
    """Test that an access token is rejected once its refresh token is removed."""
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)
    long_lived = await auth_manager.async_create_refresh_token(
        owner,
        client_name="Good Client",
        token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
        access_token_expiration=timedelta(days=3000),
    )
    assert long_lived.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
    access_token = auth_manager.async_create_access_token(long_lived)

    # Remove the refresh token from the keyring so that validation
    # ends up with an empty key.
    await auth_manager.async_remove_refresh_token(long_lived)

    outcome = await auth_manager.async_validate_access_token(access_token)
    assert outcome is None
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
    """Test that a 10000-character token is rejected by the auth manager."""
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    oversized_token = "a" * 10000
    outcome = await auth_manager.async_validate_access_token(oversized_token)
    assert outcome is None
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
    """Test that a signed token whose payload is not JSON is rejected."""
    # Encode at the JWS level so that arbitrary bytes can be used as payload.
    headers = {"alg": "HS256", "typ": "JWT"}
    token_with_invalid_json = jwt.PyJWS().encode(
        b"invalid", b"invalid", "HS256", headers
    )

    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    outcome = await auth_manager.async_validate_access_token(token_with_invalid_json)
    assert outcome is None
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
    """Test that a signed token whose JSON payload is not an object is rejected."""
    # Encode at the JWS level so that a JSON array can be used as payload.
    headers = {"alg": "HS256", "typ": "JWT"}
    token_not_a_dict_json = jwt.PyJWS().encode(
        b'["invalid"]', b"invalid", "HS256", headers
    )

    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    outcome = await auth_manager.async_validate_access_token(token_not_a_dict_json)
    assert outcome is None
async def test_access_token_that_expires_soon(mock_hass) -> None:
    """Test that a one-second access token is rejected a minute later."""
    now = dt_util.utcnow()
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)
    short_lived = await auth_manager.async_create_refresh_token(
        owner,
        client_name="Token that expires very soon",
        token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
        access_token_expiration=timedelta(seconds=1),
    )
    assert short_lived.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN

    # Right after creation the token still validates.
    access_token = auth_manager.async_create_access_token(short_lived)
    validated = await auth_manager.async_validate_access_token(access_token)
    assert validated.id == short_lived.id

    # One minute on, validation must fail.
    one_minute_later = now + timedelta(minutes=1)
    with freeze_time(one_minute_later):
        outcome = await auth_manager.async_validate_access_token(access_token)
        assert outcome is None
async def test_access_token_from_the_future(mock_hass) -> None:
    """Test that a token issued a year ahead is rejected until that time."""
    now = dt_util.utcnow()
    one_year_ahead = now + timedelta(days=365)
    auth_manager = await auth.auth_manager_from_config(mock_hass, [], [])
    owner = MockUser().add_to_auth_manager(auth_manager)

    # Create both tokens while the clock is frozen one year ahead.
    with freeze_time(one_year_ahead):
        future_refresh = await auth_manager.async_create_refresh_token(
            owner,
            client_name="Token that expires very soon",
            token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
            access_token_expiration=timedelta(days=10),
        )
        expected_type = auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
        assert future_refresh.token_type == expected_type
        access_token = auth_manager.async_create_access_token(future_refresh)

    # Back at the real current time the token must not validate.
    assert await auth_manager.async_validate_access_token(access_token) is None

    # With the clock frozen at the creation time again, it validates.
    with freeze_time(one_year_ahead):
        validated = await auth_manager.async_validate_access_token(access_token)
        assert validated.id == future_refresh.id

View file

@ -0,0 +1,12 @@
"""Tests for the Home Assistant auth jwt_wrapper module."""
import jwt
import pytest
from homeassistant.auth import jwt_wrapper
async def test_reject_access_token_with_impossible_large_size() -> None:
    """Test that the unverified decode raises for a 10000-character token."""
    oversized_token = "a" * 10000
    with pytest.raises(jwt.DecodeError):
        jwt_wrapper.unverified_hs256_token_decode(oversized_token)