Browse Source

fix: Feature parity for `a_decode_token` and `decode_token` (#616)

* Consistency for token decoding

* Mark as staticmethod

* Helper function to convert key

* Refactor key handling

* Add tests for validate=False

* Change test name

* Fix failing test

* Remove special case for str

* Some docstring

* docs: missing docstrings

---------

Co-authored-by: Richard Nemeth <ryshoooo@gmail.com>
pull/622/head v4.7.2
Cristi 1 month ago
committed by GitHub
parent
commit
ac07820524
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 59
      src/keycloak/keycloak_openid.py
  2. 75
      tests/test_keycloak_openid.py

59
src/keycloak/keycloak_openid.py

@ -28,7 +28,7 @@ class to handle authentication and token manipulation.
"""
import json
from typing import Optional
from typing import Optional, Union
from jwcrypto import jwk, jwt
@ -581,6 +581,33 @@ class KeycloakOpenID:
)
return raise_error_from_response(data_raw, KeycloakPostError)
@staticmethod
def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs):
"""Decode and optionally validate a token.
:param token: The token to verify
:type token: str
:param key: Which key should be used for validation.
If not provided, the validation is not performed and the token is implicitly valid.
:type key: Union[jwk.JWK, jwk.JWKSet, None]
:param kwargs: Additional keyword arguments for jwcrypto's JWT object
:type kwargs: dict
:returns: Decoded token
"""
# keep the function free of IO
# this way it can be used by `decode_token` and `a_decode_token`
if key is not None:
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
def decode_token(self, token, validate: bool = True, **kwargs):
"""Decode user token.
@ -603,26 +630,19 @@ class KeycloakOpenID:
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self.public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key
key = kwargs.pop("key")
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None
return self._verify_token(token, key, **kwargs)
def load_authorization_config(self, path):
"""Load Keycloak settings (authorization).
@ -1273,22 +1293,19 @@ class KeycloakOpenID:
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ await self.a_public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key
full_jwt = jwt.JWT(jwt=token, **kwargs)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None
return self._verify_token(token, key, **kwargs)
async def a_load_authorization_config(self, path):
"""Load Keycloak settings (authorization) asynchronously.

75
tests/test_keycloak_openid.py

@ -4,6 +4,8 @@ from inspect import iscoroutinefunction, signature
from typing import Tuple
from unittest import mock
import jwcrypto.jwk
import jwcrypto.jws
import pytest
from keycloak import KeycloakAdmin, KeycloakOpenID
@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token with an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = oid.token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = oid.decode_token(token=access_token)
key = oid.public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True)
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=True, key=key
)
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False)
assert decoded_access_token == decoded_invalid_access_token
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token
def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]):
"""Test load authorization config.
@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str
@pytest.mark.asyncio
async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token.
"""Test decode token asynchronously.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
@pytest.mark.asyncio
async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token asynchronously an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = await oid.a_token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = await oid.a_decode_token(token=access_token)
key = await oid.a_public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True
)
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True, key=key
)
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False
)
assert decoded_access_token == decoded_invalid_access_token
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token
@pytest.mark.asyncio
async def test_a_load_authorization_config(
oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]

Loading…
Cancel
Save