You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							460 lines
						
					
					
						
							13 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							460 lines
						
					
					
						
							13 KiB
						
					
					
				
								package oidc
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"context"
							 | 
						|
									"crypto/rand"
							 | 
						|
									"crypto/rsa"
							 | 
						|
									"encoding/base64"
							 | 
						|
									"encoding/json"
							 | 
						|
									"net/http"
							 | 
						|
									"net/http/httptest"
							 | 
						|
									"strings"
							 | 
						|
									"testing"
							 | 
						|
									"time"
							 | 
						|
								
							 | 
						|
									"github.com/golang-jwt/jwt/v5"
							 | 
						|
									"github.com/seaweedfs/seaweedfs/weed/iam/providers"
							 | 
						|
									"github.com/stretchr/testify/assert"
							 | 
						|
									"github.com/stretchr/testify/require"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// TestOIDCProviderInitialization tests OIDC provider initialization
							 | 
						|
								func TestOIDCProviderInitialization(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name    string
							 | 
						|
										config  *OIDCConfig
							 | 
						|
										wantErr bool
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name: "valid config",
							 | 
						|
											config: &OIDCConfig{
							 | 
						|
												Issuer:   "https://accounts.google.com",
							 | 
						|
												ClientID: "test-client-id",
							 | 
						|
												JWKSUri:  "https://www.googleapis.com/oauth2/v3/certs",
							 | 
						|
											},
							 | 
						|
											wantErr: false,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name: "missing issuer",
							 | 
						|
											config: &OIDCConfig{
							 | 
						|
												ClientID: "test-client-id",
							 | 
						|
											},
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name: "missing client id",
							 | 
						|
											config: &OIDCConfig{
							 | 
						|
												Issuer: "https://accounts.google.com",
							 | 
						|
											},
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name: "invalid issuer url",
							 | 
						|
											config: &OIDCConfig{
							 | 
						|
												Issuer:   "not-a-url",
							 | 
						|
												ClientID: "test-client-id",
							 | 
						|
											},
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											provider := NewOIDCProvider("test-provider")
							 | 
						|
								
							 | 
						|
											err := provider.Initialize(tt.config)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												assert.Error(t, err)
							 | 
						|
											} else {
							 | 
						|
												assert.NoError(t, err)
							 | 
						|
												assert.Equal(t, "test-provider", provider.Name())
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestOIDCProviderJWTValidation tests JWT token validation
							 | 
						|
								func TestOIDCProviderJWTValidation(t *testing.T) {
							 | 
						|
									// Set up test server with JWKS endpoint
							 | 
						|
									privateKey, publicKey := generateTestKeys(t)
							 | 
						|
								
							 | 
						|
									jwks := map[string]interface{}{
							 | 
						|
										"keys": []map[string]interface{}{
							 | 
						|
											{
							 | 
						|
												"kty": "RSA",
							 | 
						|
												"kid": "test-key-id",
							 | 
						|
												"use": "sig",
							 | 
						|
												"alg": "RS256",
							 | 
						|
												"n":   encodePublicKey(t, publicKey),
							 | 
						|
												"e":   "AQAB",
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
							 | 
						|
										if r.URL.Path == "/.well-known/openid_configuration" {
							 | 
						|
											config := map[string]interface{}{
							 | 
						|
												"issuer":   "http://" + r.Host,
							 | 
						|
												"jwks_uri": "http://" + r.Host + "/jwks",
							 | 
						|
											}
							 | 
						|
											json.NewEncoder(w).Encode(config)
							 | 
						|
										} else if r.URL.Path == "/jwks" {
							 | 
						|
											json.NewEncoder(w).Encode(jwks)
							 | 
						|
										}
							 | 
						|
									}))
							 | 
						|
									defer server.Close()
							 | 
						|
								
							 | 
						|
									provider := NewOIDCProvider("test-oidc")
							 | 
						|
									config := &OIDCConfig{
							 | 
						|
										Issuer:   server.URL,
							 | 
						|
										ClientID: "test-client",
							 | 
						|
										JWKSUri:  server.URL + "/jwks",
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									err := provider.Initialize(config)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
								
							 | 
						|
									t.Run("valid token", func(t *testing.T) {
							 | 
						|
										// Create valid JWT token
							 | 
						|
										token := createTestJWT(t, privateKey, jwt.MapClaims{
							 | 
						|
											"iss":   server.URL,
							 | 
						|
											"aud":   "test-client",
							 | 
						|
											"sub":   "user123",
							 | 
						|
											"exp":   time.Now().Add(time.Hour).Unix(),
							 | 
						|
											"iat":   time.Now().Unix(),
							 | 
						|
											"email": "user@example.com",
							 | 
						|
											"name":  "Test User",
							 | 
						|
										})
							 | 
						|
								
							 | 
						|
										claims, err := provider.ValidateToken(context.Background(), token)
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, claims)
							 | 
						|
										assert.Equal(t, "user123", claims.Subject)
							 | 
						|
										assert.Equal(t, server.URL, claims.Issuer)
							 | 
						|
								
							 | 
						|
										email, exists := claims.GetClaimString("email")
							 | 
						|
										assert.True(t, exists)
							 | 
						|
										assert.Equal(t, "user@example.com", email)
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("valid token with array audience", func(t *testing.T) {
							 | 
						|
										// Create valid JWT token with audience as an array (per RFC 7519)
							 | 
						|
										token := createTestJWT(t, privateKey, jwt.MapClaims{
							 | 
						|
											"iss":   server.URL,
							 | 
						|
											"aud":   []string{"test-client", "another-client"},
							 | 
						|
											"sub":   "user456",
							 | 
						|
											"exp":   time.Now().Add(time.Hour).Unix(),
							 | 
						|
											"iat":   time.Now().Unix(),
							 | 
						|
											"email": "user2@example.com",
							 | 
						|
											"name":  "Test User 2",
							 | 
						|
										})
							 | 
						|
								
							 | 
						|
										claims, err := provider.ValidateToken(context.Background(), token)
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, claims)
							 | 
						|
										assert.Equal(t, "user456", claims.Subject)
							 | 
						|
										assert.Equal(t, server.URL, claims.Issuer)
							 | 
						|
								
							 | 
						|
										email, exists := claims.GetClaimString("email")
							 | 
						|
										assert.True(t, exists)
							 | 
						|
										assert.Equal(t, "user2@example.com", email)
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("expired token", func(t *testing.T) {
							 | 
						|
										// Create expired JWT token
							 | 
						|
										token := createTestJWT(t, privateKey, jwt.MapClaims{
							 | 
						|
											"iss": server.URL,
							 | 
						|
											"aud": "test-client",
							 | 
						|
											"sub": "user123",
							 | 
						|
											"exp": time.Now().Add(-time.Hour).Unix(), // Expired
							 | 
						|
											"iat": time.Now().Add(-time.Hour * 2).Unix(),
							 | 
						|
										})
							 | 
						|
								
							 | 
						|
										_, err := provider.ValidateToken(context.Background(), token)
							 | 
						|
										assert.Error(t, err)
							 | 
						|
										assert.Contains(t, err.Error(), "expired")
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("invalid signature", func(t *testing.T) {
							 | 
						|
										// Create token with wrong key
							 | 
						|
										wrongKey, _ := generateTestKeys(t)
							 | 
						|
										token := createTestJWT(t, wrongKey, jwt.MapClaims{
							 | 
						|
											"iss": server.URL,
							 | 
						|
											"aud": "test-client",
							 | 
						|
											"sub": "user123",
							 | 
						|
											"exp": time.Now().Add(time.Hour).Unix(),
							 | 
						|
											"iat": time.Now().Unix(),
							 | 
						|
										})
							 | 
						|
								
							 | 
						|
										_, err := provider.ValidateToken(context.Background(), token)
							 | 
						|
										assert.Error(t, err)
							 | 
						|
									})
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestOIDCProviderAuthentication tests authentication flow
							 | 
						|
								func TestOIDCProviderAuthentication(t *testing.T) {
							 | 
						|
									// Set up test OIDC provider
							 | 
						|
									privateKey, publicKey := generateTestKeys(t)
							 | 
						|
								
							 | 
						|
									server := setupOIDCTestServer(t, publicKey)
							 | 
						|
									defer server.Close()
							 | 
						|
								
							 | 
						|
									provider := NewOIDCProvider("test-oidc")
							 | 
						|
									config := &OIDCConfig{
							 | 
						|
										Issuer:   server.URL,
							 | 
						|
										ClientID: "test-client",
							 | 
						|
										JWKSUri:  server.URL + "/jwks",
							 | 
						|
										RoleMapping: &providers.RoleMapping{
							 | 
						|
											Rules: []providers.MappingRule{
							 | 
						|
												{
							 | 
						|
													Claim: "email",
							 | 
						|
													Value: "*@example.com",
							 | 
						|
													Role:  "arn:seaweed:iam::role/UserRole",
							 | 
						|
												},
							 | 
						|
												{
							 | 
						|
													Claim: "groups",
							 | 
						|
													Value: "admins",
							 | 
						|
													Role:  "arn:seaweed:iam::role/AdminRole",
							 | 
						|
												},
							 | 
						|
											},
							 | 
						|
											DefaultRole: "arn:seaweed:iam::role/GuestRole",
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									err := provider.Initialize(config)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
								
							 | 
						|
									t.Run("successful authentication", func(t *testing.T) {
							 | 
						|
										token := createTestJWT(t, privateKey, jwt.MapClaims{
							 | 
						|
											"iss":    server.URL,
							 | 
						|
											"aud":    "test-client",
							 | 
						|
											"sub":    "user123",
							 | 
						|
											"exp":    time.Now().Add(time.Hour).Unix(),
							 | 
						|
											"iat":    time.Now().Unix(),
							 | 
						|
											"email":  "user@example.com",
							 | 
						|
											"name":   "Test User",
							 | 
						|
											"groups": []string{"users", "developers"},
							 | 
						|
										})
							 | 
						|
								
							 | 
						|
										identity, err := provider.Authenticate(context.Background(), token)
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, identity)
							 | 
						|
										assert.Equal(t, "user123", identity.UserID)
							 | 
						|
										assert.Equal(t, "user@example.com", identity.Email)
							 | 
						|
										assert.Equal(t, "Test User", identity.DisplayName)
							 | 
						|
										assert.Equal(t, "test-oidc", identity.Provider)
							 | 
						|
										assert.Contains(t, identity.Groups, "users")
							 | 
						|
										assert.Contains(t, identity.Groups, "developers")
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("authentication with invalid token", func(t *testing.T) {
							 | 
						|
										_, err := provider.Authenticate(context.Background(), "invalid-token")
							 | 
						|
										assert.Error(t, err)
							 | 
						|
									})
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestOIDCProviderUserInfo tests user info retrieval
							 | 
						|
								func TestOIDCProviderUserInfo(t *testing.T) {
							 | 
						|
									// Set up test server with UserInfo endpoint
							 | 
						|
									server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
							 | 
						|
										if r.URL.Path == "/userinfo" {
							 | 
						|
											// Check for Authorization header
							 | 
						|
											authHeader := r.Header.Get("Authorization")
							 | 
						|
											if !strings.HasPrefix(authHeader, "Bearer ") {
							 | 
						|
												w.WriteHeader(http.StatusUnauthorized)
							 | 
						|
												w.Write([]byte(`{"error": "unauthorized"}`))
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											accessToken := strings.TrimPrefix(authHeader, "Bearer ")
							 | 
						|
								
							 | 
						|
											// Return 401 for explicitly invalid tokens
							 | 
						|
											if accessToken == "invalid-token" {
							 | 
						|
												w.WriteHeader(http.StatusUnauthorized)
							 | 
						|
												w.Write([]byte(`{"error": "invalid_token"}`))
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Mock user info response
							 | 
						|
											userInfo := map[string]interface{}{
							 | 
						|
												"sub":    "user123",
							 | 
						|
												"email":  "user@example.com",
							 | 
						|
												"name":   "Test User",
							 | 
						|
												"groups": []string{"users", "developers"},
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Customize response based on token
							 | 
						|
											if strings.Contains(accessToken, "admin") {
							 | 
						|
												userInfo["groups"] = []string{"admins"}
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											w.Header().Set("Content-Type", "application/json")
							 | 
						|
											json.NewEncoder(w).Encode(userInfo)
							 | 
						|
										}
							 | 
						|
									}))
							 | 
						|
									defer server.Close()
							 | 
						|
								
							 | 
						|
									provider := NewOIDCProvider("test-oidc")
							 | 
						|
									config := &OIDCConfig{
							 | 
						|
										Issuer:      server.URL,
							 | 
						|
										ClientID:    "test-client",
							 | 
						|
										UserInfoUri: server.URL + "/userinfo",
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									err := provider.Initialize(config)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
								
							 | 
						|
									t.Run("get user info with access token", func(t *testing.T) {
							 | 
						|
										// Test using access token (real UserInfo endpoint call)
							 | 
						|
										identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, identity)
							 | 
						|
										assert.Equal(t, "user123", identity.UserID)
							 | 
						|
										assert.Equal(t, "user@example.com", identity.Email)
							 | 
						|
										assert.Equal(t, "Test User", identity.DisplayName)
							 | 
						|
										assert.Contains(t, identity.Groups, "users")
							 | 
						|
										assert.Contains(t, identity.Groups, "developers")
							 | 
						|
										assert.Equal(t, "test-oidc", identity.Provider)
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("get admin user info", func(t *testing.T) {
							 | 
						|
										// Test admin token response
							 | 
						|
										identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, identity)
							 | 
						|
										assert.Equal(t, "user123", identity.UserID)
							 | 
						|
										assert.Contains(t, identity.Groups, "admins")
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("get user info without token", func(t *testing.T) {
							 | 
						|
										// Test without access token (should fail)
							 | 
						|
										_, err := provider.GetUserInfoWithToken(context.Background(), "")
							 | 
						|
										assert.Error(t, err)
							 | 
						|
										assert.Contains(t, err.Error(), "access token cannot be empty")
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("get user info with invalid token", func(t *testing.T) {
							 | 
						|
										// Test with invalid access token (should get 401)
							 | 
						|
										_, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
							 | 
						|
										assert.Error(t, err)
							 | 
						|
										assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("get user info with custom claims mapping", func(t *testing.T) {
							 | 
						|
										// Create provider with custom claims mapping
							 | 
						|
										customProvider := NewOIDCProvider("test-custom-oidc")
							 | 
						|
										customConfig := &OIDCConfig{
							 | 
						|
											Issuer:      server.URL,
							 | 
						|
											ClientID:    "test-client",
							 | 
						|
											UserInfoUri: server.URL + "/userinfo",
							 | 
						|
											ClaimsMapping: map[string]string{
							 | 
						|
												"customEmail": "email",
							 | 
						|
												"customName":  "name",
							 | 
						|
											},
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										err := customProvider.Initialize(customConfig)
							 | 
						|
										require.NoError(t, err)
							 | 
						|
								
							 | 
						|
										identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
							 | 
						|
										require.NoError(t, err)
							 | 
						|
										require.NotNil(t, identity)
							 | 
						|
								
							 | 
						|
										// Standard claims should still work
							 | 
						|
										assert.Equal(t, "user123", identity.UserID)
							 | 
						|
										assert.Equal(t, "user@example.com", identity.Email)
							 | 
						|
										assert.Equal(t, "Test User", identity.DisplayName)
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									t.Run("get user info with empty id", func(t *testing.T) {
							 | 
						|
										_, err := provider.GetUserInfo(context.Background(), "")
							 | 
						|
										assert.Error(t, err)
							 | 
						|
									})
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Helper functions for testing
							 | 
						|
								
							 | 
						|
								func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
							 | 
						|
									privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
									return privateKey, &privateKey.PublicKey
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
							 | 
						|
									token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
							 | 
						|
									token.Header["kid"] = "test-key-id"
							 | 
						|
								
							 | 
						|
									tokenString, err := token.SignedString(privateKey)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
									return tokenString
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// encodePublicKey returns the RSA modulus (N) of publicKey as an unpadded
								// base64url string, for use as the "n" member of a JWK. The t parameter is
								// accepted but not used.
								func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
									modulus := publicKey.N.Bytes()
									return base64.RawURLEncoding.EncodeToString(modulus)
								}
							 | 
						|
								
							 | 
						|
								func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
							 | 
						|
									jwks := map[string]interface{}{
							 | 
						|
										"keys": []map[string]interface{}{
							 | 
						|
											{
							 | 
						|
												"kty": "RSA",
							 | 
						|
												"kid": "test-key-id",
							 | 
						|
												"use": "sig",
							 | 
						|
												"alg": "RS256",
							 | 
						|
												"n":   encodePublicKey(t, publicKey),
							 | 
						|
												"e":   "AQAB",
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
							 | 
						|
										switch r.URL.Path {
							 | 
						|
										case "/.well-known/openid_configuration":
							 | 
						|
											config := map[string]interface{}{
							 | 
						|
												"issuer":            "http://" + r.Host,
							 | 
						|
												"jwks_uri":          "http://" + r.Host + "/jwks",
							 | 
						|
												"userinfo_endpoint": "http://" + r.Host + "/userinfo",
							 | 
						|
											}
							 | 
						|
											json.NewEncoder(w).Encode(config)
							 | 
						|
										case "/jwks":
							 | 
						|
											json.NewEncoder(w).Encode(jwks)
							 | 
						|
										case "/userinfo":
							 | 
						|
											// Mock UserInfo endpoint
							 | 
						|
											authHeader := r.Header.Get("Authorization")
							 | 
						|
											if !strings.HasPrefix(authHeader, "Bearer ") {
							 | 
						|
												w.WriteHeader(http.StatusUnauthorized)
							 | 
						|
												w.Write([]byte(`{"error": "unauthorized"}`))
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											accessToken := strings.TrimPrefix(authHeader, "Bearer ")
							 | 
						|
								
							 | 
						|
											// Return 401 for explicitly invalid tokens
							 | 
						|
											if accessToken == "invalid-token" {
							 | 
						|
												w.WriteHeader(http.StatusUnauthorized)
							 | 
						|
												w.Write([]byte(`{"error": "invalid_token"}`))
							 | 
						|
												return
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Mock user info response based on access token
							 | 
						|
											userInfo := map[string]interface{}{
							 | 
						|
												"sub":    "user123",
							 | 
						|
												"email":  "user@example.com",
							 | 
						|
												"name":   "Test User",
							 | 
						|
												"groups": []string{"users", "developers"},
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Customize response based on token
							 | 
						|
											if strings.Contains(accessToken, "admin") {
							 | 
						|
												userInfo["groups"] = []string{"admins"}
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											w.Header().Set("Content-Type", "application/json")
							 | 
						|
											json.NewEncoder(w).Encode(userInfo)
							 | 
						|
										default:
							 | 
						|
											http.NotFound(w, r)
							 | 
						|
										}
							 | 
						|
									}))
							 | 
						|
								}
							 |