Browse Source

iam: add ECDSA support for OIDC token validation (#8166)

* iam: add ECDSA support for OIDC token validation

Fixes seaweedfs/seaweedfs#8148

* iam: refactor OIDC ECDSA tests and add failure cases

- Refactored TestOIDCProviderJWTValidationECDSA to use t.Run
- Added sub-tests for expired token, wrong key, invalid issuer, and invalid audience

* Update weed/iam/oidc/oidc_provider_test.go

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* iam: improve error type assertions for OIDC invalid signature tests

- Updated both RSA and ECDSA tests to specifically check for ErrProviderInvalidToken

* iam: pad EC coordinates in OIDC tests to comply with RFC 7518

- Coordinates are now zero-padded to the full field size (e.g., 32 bytes for P-256)
- Ensures interoperability with strict OIDC providers

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
pull/8115/merge
Chris Lu 3 days ago
committed by GitHub
parent
commit
23c25379ca
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 2
      weed/iam/oidc/oidc_provider.go
  2. 150
      weed/iam/oidc/oidc_provider_test.go

2
weed/iam/oidc/oidc_provider.go

@ -397,7 +397,7 @@ func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*provid
validatedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
switch token.Method.(type) {
case *jwt.SigningMethodRSA:
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
return publicKey, nil
default:
return nil, fmt.Errorf("unsupported signing method: %v", token.Header["alg"])

150
weed/iam/oidc/oidc_provider_test.go

@ -2,6 +2,8 @@ package oidc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
@ -188,7 +190,123 @@ func TestOIDCProviderJWTValidation(t *testing.T) {
})
_, err := provider.ValidateToken(context.Background(), token)
assert.Error(t, err)
require.Error(t, err)
assert.ErrorIs(t, err, providers.ErrProviderInvalidToken)
})
}
func TestOIDCProviderJWTValidationECDSA(t *testing.T) {
privateKey, publicKey := generateTestECKeys(t)
x, y := encodeECPublicKey(t, publicKey)
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "EC",
"kid": "test-ec-key-id",
"use": "sig",
"alg": "ES256",
"crv": "P-256",
"x": x,
"y": y,
},
},
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid_configuration" {
config := map[string]interface{}{
"issuer": "http://" + r.Host,
"jwks_uri": "http://" + r.Host + "/jwks",
}
json.NewEncoder(w).Encode(config)
} else if r.URL.Path == "/jwks" {
json.NewEncoder(w).Encode(jwks)
}
}))
defer server.Close()
provider := NewOIDCProvider("test-oidc-ecdsa")
config := &OIDCConfig{
Issuer: server.URL,
ClientID: "test-client",
JWKSUri: server.URL + "/jwks",
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("valid token", func(t *testing.T) {
token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user789",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
claims, err := provider.ValidateToken(context.Background(), token)
require.NoError(t, err)
require.NotNil(t, claims)
assert.Equal(t, "user789", claims.Subject)
assert.Equal(t, server.URL, claims.Issuer)
})
t.Run("expired token", func(t *testing.T) {
token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user789",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-time.Hour * 2).Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
require.Error(t, err)
assert.ErrorIs(t, err, providers.ErrProviderTokenExpired)
})
t.Run("invalid signature", func(t *testing.T) {
wrongKey, _ := generateTestECKeys(t)
token := createTestECDSAJWT(t, wrongKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user789",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
require.Error(t, err)
assert.ErrorIs(t, err, providers.ErrProviderInvalidToken)
})
t.Run("invalid issuer", func(t *testing.T) {
token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{
"iss": "http://wrong-issuer",
"aud": "test-client",
"sub": "user789",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
require.Error(t, err)
assert.ErrorIs(t, err, providers.ErrProviderInvalidIssuer)
})
t.Run("invalid audience", func(t *testing.T) {
token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "wrong-client",
"sub": "user789",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
require.Error(t, err)
assert.ErrorIs(t, err, providers.ErrProviderInvalidAudience)
})
}
@ -435,6 +553,12 @@ func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
return privateKey, &privateKey.PublicKey
}
func generateTestECKeys(t *testing.T) (*ecdsa.PrivateKey, *ecdsa.PublicKey) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
return privateKey, &privateKey.PublicKey
}
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
@ -444,11 +568,35 @@ func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaim
return tokenString
}
func createTestECDSAJWT(t *testing.T, privateKey *ecdsa.PrivateKey, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
token.Header["kid"] = "test-ec-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
// Properly encode the RSA modulus (N) as base64url
return base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes())
}
func encodeECPublicKey(t *testing.T, publicKey *ecdsa.PublicKey) (string, string) {
// RFC 7518 §6.2.1.2 requires EC coordinates to be zero-padded to the full field size
curveParams := publicKey.Curve.Params()
size := (curveParams.BitSize + 7) / 8
xBytes := publicKey.X.Bytes()
yBytes := publicKey.Y.Bytes()
xPadded := make([]byte, size)
yPadded := make([]byte, size)
// Right-align the coordinate bytes and leave leading zeros for padding
copy(xPadded[size-len(xBytes):], xBytes)
copy(yPadded[size-len(yBytes):], yBytes)
return base64.RawURLEncoding.EncodeToString(xPadded),
base64.RawURLEncoding.EncodeToString(yPadded)
}
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
jwks := map[string]interface{}{
"keys": []map[string]interface{}{

Loading…
Cancel
Save