diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index b95b17fbd..dda970665 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -397,7 +397,7 @@ func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*provid validatedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { // Verify signing method switch token.Method.(type) { - case *jwt.SigningMethodRSA: + case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: return publicKey, nil default: return nil, fmt.Errorf("unsupported signing method: %v", token.Header["alg"]) diff --git a/weed/iam/oidc/oidc_provider_test.go b/weed/iam/oidc/oidc_provider_test.go index 5a96c2c86..b226929d2 100644 --- a/weed/iam/oidc/oidc_provider_test.go +++ b/weed/iam/oidc/oidc_provider_test.go @@ -2,6 +2,8 @@ package oidc import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/base64" @@ -188,7 +190,123 @@ func TestOIDCProviderJWTValidation(t *testing.T) { }) _, err := provider.ValidateToken(context.Background(), token) - assert.Error(t, err) + require.Error(t, err) + assert.ErrorIs(t, err, providers.ErrProviderInvalidToken) + }) +} + +func TestOIDCProviderJWTValidationECDSA(t *testing.T) { + privateKey, publicKey := generateTestECKeys(t) + x, y := encodeECPublicKey(t, publicKey) + + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "EC", + "kid": "test-ec-key-id", + "use": "sig", + "alg": "ES256", + "crv": "P-256", + "x": x, + "y": y, + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid_configuration" { + config := map[string]interface{}{ + "issuer": "http://" + r.Host, + "jwks_uri": "http://" + r.Host + "/jwks", + } + json.NewEncoder(w).Encode(config) + } else if r.URL.Path == "/jwks" { + json.NewEncoder(w).Encode(jwks) + } + })) + defer server.Close() + + provider := NewOIDCProvider("test-oidc-ecdsa") + config := &OIDCConfig{ + Issuer: server.URL, + ClientID: "test-client", + JWKSUri: server.URL + "/jwks", + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("valid token", func(t *testing.T) { + token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user789", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + claims, err := provider.ValidateToken(context.Background(), token) + require.NoError(t, err) + require.NotNil(t, claims) + assert.Equal(t, "user789", claims.Subject) + assert.Equal(t, server.URL, claims.Issuer) + }) + + t.Run("expired token", func(t *testing.T) { + token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user789", + "exp": time.Now().Add(-time.Hour).Unix(), + "iat": time.Now().Add(-time.Hour * 2).Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + require.Error(t, err) + assert.ErrorIs(t, err, providers.ErrProviderTokenExpired) + }) + + t.Run("invalid signature", func(t *testing.T) { + wrongKey, _ := generateTestECKeys(t) + token := createTestECDSAJWT(t, wrongKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user789", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + require.Error(t, err) + assert.ErrorIs(t, err, providers.ErrProviderInvalidToken) + }) + + t.Run("invalid issuer", func(t *testing.T) { + token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{ + "iss": "http://wrong-issuer", + "aud": "test-client", + "sub": "user789", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + require.Error(t, err) + assert.ErrorIs(t, err, providers.ErrProviderInvalidIssuer) + }) + + t.Run("invalid audience", func(t *testing.T) { + token := createTestECDSAJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "wrong-client", + "sub": "user789", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + require.Error(t, err) + assert.ErrorIs(t, err, providers.ErrProviderInvalidAudience) }) } @@ -435,6 +553,12 @@ func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { return privateKey, &privateKey.PublicKey } +func generateTestECKeys(t *testing.T) (*ecdsa.PrivateKey, *ecdsa.PublicKey) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return privateKey, &privateKey.PublicKey +} + func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string { token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = "test-key-id" @@ -444,11 +568,35 @@ func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaim return tokenString } +func createTestECDSAJWT(t *testing.T, privateKey *ecdsa.PrivateKey, claims jwt.MapClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = "test-ec-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + return tokenString +} + func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string { // Properly encode the RSA modulus (N) as base64url return base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()) } +func encodeECPublicKey(t *testing.T, publicKey *ecdsa.PublicKey) (string, string) { + // RFC 7518 ยง6.2.1.2 requires EC coordinates to be zero-padded to the full field size + curveParams := publicKey.Curve.Params() + size := (curveParams.BitSize + 7) / 8 + xBytes := publicKey.X.Bytes() + yBytes := publicKey.Y.Bytes() + xPadded := make([]byte, size) + yPadded := make([]byte, size) + // Right-align the coordinate bytes and leave leading zeros for padding + copy(xPadded[size-len(xBytes):], xBytes) + copy(yPadded[size-len(yBytes):], yBytes) + return base64.RawURLEncoding.EncodeToString(xPadded), + base64.RawURLEncoding.EncodeToString(yPadded) +} + func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server { jwks := map[string]interface{}{ "keys": []map[string]interface{}{