Browse Source

fix: Feature parity for `a_decode_token` and `decode_token` (#616)

* Consistency for token decoding

* Mark as staticmethod

* Helper function to convert key

* Refactor key handling

* Add tests for validate=False

* Change test name

* Fix failing test

* Remove special case for str

* Some docstring

* docs: missing docstrings

---------

Co-authored-by: Richard Nemeth <ryshoooo@gmail.com>
master v4.7.2
Cristi 12 hours ago
committed by GitHub
parent
commit
ac07820524
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 59
      src/keycloak/keycloak_openid.py
  2. 75
      tests/test_keycloak_openid.py

59
src/keycloak/keycloak_openid.py

@ -28,7 +28,7 @@ class to handle authentication and token manipulation.
""" """
import json import json
from typing import Optional
from typing import Optional, Union
from jwcrypto import jwk, jwt from jwcrypto import jwk, jwt
@ -581,6 +581,33 @@ class KeycloakOpenID:
) )
return raise_error_from_response(data_raw, KeycloakPostError) return raise_error_from_response(data_raw, KeycloakPostError)
@staticmethod
def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs):
"""Decode and optionally validate a token.
:param token: The token to verify
:type token: str
:param key: Which key should be used for validation.
If not provided, the validation is not performed and the token is implicitly valid.
:type key: Union[jwk.JWK, jwk.JWKSet, None]
:param kwargs: Additional keyword arguments for jwcrypto's JWT object
:type kwargs: dict
:returns: Decoded token
"""
# keep the function free of IO
# this way it can be used by `decode_token` and `a_decode_token`
if key is not None:
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
def decode_token(self, token, validate: bool = True, **kwargs): def decode_token(self, token, validate: bool = True, **kwargs):
"""Decode user token. """Decode user token.
@ -603,26 +630,19 @@ class KeycloakOpenID:
:returns: Decoded token :returns: Decoded token
:rtype: dict :rtype: dict
""" """
key = kwargs.pop("key", None)
if validate: if validate:
if "key" not in kwargs:
if key is None:
key = ( key = (
"-----BEGIN PUBLIC KEY-----\n" "-----BEGIN PUBLIC KEY-----\n"
+ self.public_key() + self.public_key()
+ "\n-----END PUBLIC KEY-----" + "\n-----END PUBLIC KEY-----"
) )
key = jwk.JWK.from_pem(key.encode("utf-8")) key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key
key = kwargs.pop("key")
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else: else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None
return self._verify_token(token, key, **kwargs)
def load_authorization_config(self, path): def load_authorization_config(self, path):
"""Load Keycloak settings (authorization). """Load Keycloak settings (authorization).
@ -1273,22 +1293,19 @@ class KeycloakOpenID:
:returns: Decoded token :returns: Decoded token
:rtype: dict :rtype: dict
""" """
key = kwargs.pop("key", None)
if validate: if validate:
if "key" not in kwargs:
if key is None:
key = ( key = (
"-----BEGIN PUBLIC KEY-----\n" "-----BEGIN PUBLIC KEY-----\n"
+ await self.a_public_key() + await self.a_public_key()
+ "\n-----END PUBLIC KEY-----" + "\n-----END PUBLIC KEY-----"
) )
key = jwk.JWK.from_pem(key.encode("utf-8")) key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key
full_jwt = jwt.JWT(jwt=token, **kwargs)
return jwt.json_decode(full_jwt.claims)
else: else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None
return self._verify_token(token, key, **kwargs)
async def a_load_authorization_config(self, path): async def a_load_authorization_config(self, path):
"""Load Keycloak settings (authorization) asynchronously. """Load Keycloak settings (authorization) asynchronously.

75
tests/test_keycloak_openid.py

@ -4,6 +4,8 @@ from inspect import iscoroutinefunction, signature
from typing import Tuple from typing import Tuple
from unittest import mock from unittest import mock
import jwcrypto.jwk
import jwcrypto.jws
import pytest import pytest
from keycloak import KeycloakAdmin, KeycloakOpenID from keycloak import KeycloakAdmin, KeycloakOpenID
@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token with an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = oid.token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = oid.decode_token(token=access_token)
key = oid.public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True)
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=True, key=key
)
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False)
assert decoded_access_token == decoded_invalid_access_token
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token
def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]): def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]):
"""Test load authorization config. """Test load authorization config.
@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]): async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token.
"""Test decode token asynchronously.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials :param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str] :type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
@pytest.mark.asyncio
async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token asynchronously an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = await oid.a_token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = await oid.a_decode_token(token=access_token)
key = await oid.a_public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True
)
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True, key=key
)
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False
)
assert decoded_access_token == decoded_invalid_access_token
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_a_load_authorization_config( async def test_a_load_authorization_config(
oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str] oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]

Loading…
Cancel
Save