diff --git a/src/keycloak/keycloak_openid.py b/src/keycloak/keycloak_openid.py index e080e8a..665d21c 100644 --- a/src/keycloak/keycloak_openid.py +++ b/src/keycloak/keycloak_openid.py @@ -28,7 +28,7 @@ class to handle authentication and token manipulation. """ import json -from typing import Optional +from typing import Optional, Union from jwcrypto import jwk, jwt @@ -582,9 +582,13 @@ class KeycloakOpenID: return raise_error_from_response(data_raw, KeycloakPostError) @staticmethod - def _verify_token(token, key, **kwargs): + def _verify_token(token, key: Union[str, jwk.JWK, jwk.JWKSet, None], **kwargs): # keep the function free of IO # this way it can be used by `decode_token` and `a_decode_token` + if isinstance(key, str): + key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----" + key = jwk.JWK.from_pem(key.encode("utf-8")) + if key is not None: leeway = kwargs.pop("leeway", 60) full_jwt = jwt.JWT(jwt=token, **kwargs) @@ -596,11 +600,6 @@ class KeycloakOpenID: full_jwt.token.objects["valid"] = True return json.loads(full_jwt.token.payload.decode("utf-8")) - @staticmethod - def _public_key_to_jwk(key: str) -> jwk.JWK: - key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----" - return jwk.JWK.from_pem(key.encode("utf-8")) - def decode_token(self, token, validate: bool = True, **kwargs): """Decode user token. @@ -625,7 +624,7 @@ class KeycloakOpenID: """ key = kwargs.pop("key", None) if validate and key is None: - key = self._public_key_to_jwk(self.public_key()) + key = self.public_key() return self._verify_token(token, key, **kwargs) @@ -1261,7 +1260,7 @@ class KeycloakOpenID: """ key = kwargs.pop("key", None) if validate and key is None: - key = self._public_key_to_jwk(await self.a_public_key()) + key = await self.a_public_key() return self._verify_token(token, key, **kwargs)