You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							460 lines
						
					
					
						
							13 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							460 lines
						
					
					
						
							13 KiB
						
					
					
				| package oidc | |
| 
 | |
| import ( | |
| 	"context" | |
| 	"crypto/rand" | |
| 	"crypto/rsa" | |
| 	"encoding/base64" | |
| 	"encoding/json" | |
| 	"net/http" | |
| 	"net/http/httptest" | |
| 	"strings" | |
| 	"testing" | |
| 	"time" | |
| 
 | |
| 	"github.com/golang-jwt/jwt/v5" | |
| 	"github.com/seaweedfs/seaweedfs/weed/iam/providers" | |
| 	"github.com/stretchr/testify/assert" | |
| 	"github.com/stretchr/testify/require" | |
| ) | |
| 
 | |
| // TestOIDCProviderInitialization tests OIDC provider initialization | |
| func TestOIDCProviderInitialization(t *testing.T) { | |
| 	tests := []struct { | |
| 		name    string | |
| 		config  *OIDCConfig | |
| 		wantErr bool | |
| 	}{ | |
| 		{ | |
| 			name: "valid config", | |
| 			config: &OIDCConfig{ | |
| 				Issuer:   "https://accounts.google.com", | |
| 				ClientID: "test-client-id", | |
| 				JWKSUri:  "https://www.googleapis.com/oauth2/v3/certs", | |
| 			}, | |
| 			wantErr: false, | |
| 		}, | |
| 		{ | |
| 			name: "missing issuer", | |
| 			config: &OIDCConfig{ | |
| 				ClientID: "test-client-id", | |
| 			}, | |
| 			wantErr: true, | |
| 		}, | |
| 		{ | |
| 			name: "missing client id", | |
| 			config: &OIDCConfig{ | |
| 				Issuer: "https://accounts.google.com", | |
| 			}, | |
| 			wantErr: true, | |
| 		}, | |
| 		{ | |
| 			name: "invalid issuer url", | |
| 			config: &OIDCConfig{ | |
| 				Issuer:   "not-a-url", | |
| 				ClientID: "test-client-id", | |
| 			}, | |
| 			wantErr: true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			provider := NewOIDCProvider("test-provider") | |
| 
 | |
| 			err := provider.Initialize(tt.config) | |
| 
 | |
| 			if tt.wantErr { | |
| 				assert.Error(t, err) | |
| 			} else { | |
| 				assert.NoError(t, err) | |
| 				assert.Equal(t, "test-provider", provider.Name()) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestOIDCProviderJWTValidation tests JWT token validation | |
| func TestOIDCProviderJWTValidation(t *testing.T) { | |
| 	// Set up test server with JWKS endpoint | |
| 	privateKey, publicKey := generateTestKeys(t) | |
| 
 | |
| 	jwks := map[string]interface{}{ | |
| 		"keys": []map[string]interface{}{ | |
| 			{ | |
| 				"kty": "RSA", | |
| 				"kid": "test-key-id", | |
| 				"use": "sig", | |
| 				"alg": "RS256", | |
| 				"n":   encodePublicKey(t, publicKey), | |
| 				"e":   "AQAB", | |
| 			}, | |
| 		}, | |
| 	} | |
| 
 | |
| 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| 		if r.URL.Path == "/.well-known/openid_configuration" { | |
| 			config := map[string]interface{}{ | |
| 				"issuer":   "http://" + r.Host, | |
| 				"jwks_uri": "http://" + r.Host + "/jwks", | |
| 			} | |
| 			json.NewEncoder(w).Encode(config) | |
| 		} else if r.URL.Path == "/jwks" { | |
| 			json.NewEncoder(w).Encode(jwks) | |
| 		} | |
| 	})) | |
| 	defer server.Close() | |
| 
 | |
| 	provider := NewOIDCProvider("test-oidc") | |
| 	config := &OIDCConfig{ | |
| 		Issuer:   server.URL, | |
| 		ClientID: "test-client", | |
| 		JWKSUri:  server.URL + "/jwks", | |
| 	} | |
| 
 | |
| 	err := provider.Initialize(config) | |
| 	require.NoError(t, err) | |
| 
 | |
| 	t.Run("valid token", func(t *testing.T) { | |
| 		// Create valid JWT token | |
| 		token := createTestJWT(t, privateKey, jwt.MapClaims{ | |
| 			"iss":   server.URL, | |
| 			"aud":   "test-client", | |
| 			"sub":   "user123", | |
| 			"exp":   time.Now().Add(time.Hour).Unix(), | |
| 			"iat":   time.Now().Unix(), | |
| 			"email": "user@example.com", | |
| 			"name":  "Test User", | |
| 		}) | |
| 
 | |
| 		claims, err := provider.ValidateToken(context.Background(), token) | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, claims) | |
| 		assert.Equal(t, "user123", claims.Subject) | |
| 		assert.Equal(t, server.URL, claims.Issuer) | |
| 
 | |
| 		email, exists := claims.GetClaimString("email") | |
| 		assert.True(t, exists) | |
| 		assert.Equal(t, "user@example.com", email) | |
| 	}) | |
| 
 | |
| 	t.Run("valid token with array audience", func(t *testing.T) { | |
| 		// Create valid JWT token with audience as an array (per RFC 7519) | |
| 		token := createTestJWT(t, privateKey, jwt.MapClaims{ | |
| 			"iss":   server.URL, | |
| 			"aud":   []string{"test-client", "another-client"}, | |
| 			"sub":   "user456", | |
| 			"exp":   time.Now().Add(time.Hour).Unix(), | |
| 			"iat":   time.Now().Unix(), | |
| 			"email": "user2@example.com", | |
| 			"name":  "Test User 2", | |
| 		}) | |
| 
 | |
| 		claims, err := provider.ValidateToken(context.Background(), token) | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, claims) | |
| 		assert.Equal(t, "user456", claims.Subject) | |
| 		assert.Equal(t, server.URL, claims.Issuer) | |
| 
 | |
| 		email, exists := claims.GetClaimString("email") | |
| 		assert.True(t, exists) | |
| 		assert.Equal(t, "user2@example.com", email) | |
| 	}) | |
| 
 | |
| 	t.Run("expired token", func(t *testing.T) { | |
| 		// Create expired JWT token | |
| 		token := createTestJWT(t, privateKey, jwt.MapClaims{ | |
| 			"iss": server.URL, | |
| 			"aud": "test-client", | |
| 			"sub": "user123", | |
| 			"exp": time.Now().Add(-time.Hour).Unix(), // Expired | |
| 			"iat": time.Now().Add(-time.Hour * 2).Unix(), | |
| 		}) | |
| 
 | |
| 		_, err := provider.ValidateToken(context.Background(), token) | |
| 		assert.Error(t, err) | |
| 		assert.Contains(t, err.Error(), "expired") | |
| 	}) | |
| 
 | |
| 	t.Run("invalid signature", func(t *testing.T) { | |
| 		// Create token with wrong key | |
| 		wrongKey, _ := generateTestKeys(t) | |
| 		token := createTestJWT(t, wrongKey, jwt.MapClaims{ | |
| 			"iss": server.URL, | |
| 			"aud": "test-client", | |
| 			"sub": "user123", | |
| 			"exp": time.Now().Add(time.Hour).Unix(), | |
| 			"iat": time.Now().Unix(), | |
| 		}) | |
| 
 | |
| 		_, err := provider.ValidateToken(context.Background(), token) | |
| 		assert.Error(t, err) | |
| 	}) | |
| } | |
| 
 | |
| // TestOIDCProviderAuthentication tests authentication flow | |
| func TestOIDCProviderAuthentication(t *testing.T) { | |
| 	// Set up test OIDC provider | |
| 	privateKey, publicKey := generateTestKeys(t) | |
| 
 | |
| 	server := setupOIDCTestServer(t, publicKey) | |
| 	defer server.Close() | |
| 
 | |
| 	provider := NewOIDCProvider("test-oidc") | |
| 	config := &OIDCConfig{ | |
| 		Issuer:   server.URL, | |
| 		ClientID: "test-client", | |
| 		JWKSUri:  server.URL + "/jwks", | |
| 		RoleMapping: &providers.RoleMapping{ | |
| 			Rules: []providers.MappingRule{ | |
| 				{ | |
| 					Claim: "email", | |
| 					Value: "*@example.com", | |
| 					Role:  "arn:seaweed:iam::role/UserRole", | |
| 				}, | |
| 				{ | |
| 					Claim: "groups", | |
| 					Value: "admins", | |
| 					Role:  "arn:seaweed:iam::role/AdminRole", | |
| 				}, | |
| 			}, | |
| 			DefaultRole: "arn:seaweed:iam::role/GuestRole", | |
| 		}, | |
| 	} | |
| 
 | |
| 	err := provider.Initialize(config) | |
| 	require.NoError(t, err) | |
| 
 | |
| 	t.Run("successful authentication", func(t *testing.T) { | |
| 		token := createTestJWT(t, privateKey, jwt.MapClaims{ | |
| 			"iss":    server.URL, | |
| 			"aud":    "test-client", | |
| 			"sub":    "user123", | |
| 			"exp":    time.Now().Add(time.Hour).Unix(), | |
| 			"iat":    time.Now().Unix(), | |
| 			"email":  "user@example.com", | |
| 			"name":   "Test User", | |
| 			"groups": []string{"users", "developers"}, | |
| 		}) | |
| 
 | |
| 		identity, err := provider.Authenticate(context.Background(), token) | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, identity) | |
| 		assert.Equal(t, "user123", identity.UserID) | |
| 		assert.Equal(t, "user@example.com", identity.Email) | |
| 		assert.Equal(t, "Test User", identity.DisplayName) | |
| 		assert.Equal(t, "test-oidc", identity.Provider) | |
| 		assert.Contains(t, identity.Groups, "users") | |
| 		assert.Contains(t, identity.Groups, "developers") | |
| 	}) | |
| 
 | |
| 	t.Run("authentication with invalid token", func(t *testing.T) { | |
| 		_, err := provider.Authenticate(context.Background(), "invalid-token") | |
| 		assert.Error(t, err) | |
| 	}) | |
| } | |
| 
 | |
| // TestOIDCProviderUserInfo tests user info retrieval | |
| func TestOIDCProviderUserInfo(t *testing.T) { | |
| 	// Set up test server with UserInfo endpoint | |
| 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| 		if r.URL.Path == "/userinfo" { | |
| 			// Check for Authorization header | |
| 			authHeader := r.Header.Get("Authorization") | |
| 			if !strings.HasPrefix(authHeader, "Bearer ") { | |
| 				w.WriteHeader(http.StatusUnauthorized) | |
| 				w.Write([]byte(`{"error": "unauthorized"}`)) | |
| 				return | |
| 			} | |
| 
 | |
| 			accessToken := strings.TrimPrefix(authHeader, "Bearer ") | |
| 
 | |
| 			// Return 401 for explicitly invalid tokens | |
| 			if accessToken == "invalid-token" { | |
| 				w.WriteHeader(http.StatusUnauthorized) | |
| 				w.Write([]byte(`{"error": "invalid_token"}`)) | |
| 				return | |
| 			} | |
| 
 | |
| 			// Mock user info response | |
| 			userInfo := map[string]interface{}{ | |
| 				"sub":    "user123", | |
| 				"email":  "user@example.com", | |
| 				"name":   "Test User", | |
| 				"groups": []string{"users", "developers"}, | |
| 			} | |
| 
 | |
| 			// Customize response based on token | |
| 			if strings.Contains(accessToken, "admin") { | |
| 				userInfo["groups"] = []string{"admins"} | |
| 			} | |
| 
 | |
| 			w.Header().Set("Content-Type", "application/json") | |
| 			json.NewEncoder(w).Encode(userInfo) | |
| 		} | |
| 	})) | |
| 	defer server.Close() | |
| 
 | |
| 	provider := NewOIDCProvider("test-oidc") | |
| 	config := &OIDCConfig{ | |
| 		Issuer:      server.URL, | |
| 		ClientID:    "test-client", | |
| 		UserInfoUri: server.URL + "/userinfo", | |
| 	} | |
| 
 | |
| 	err := provider.Initialize(config) | |
| 	require.NoError(t, err) | |
| 
 | |
| 	t.Run("get user info with access token", func(t *testing.T) { | |
| 		// Test using access token (real UserInfo endpoint call) | |
| 		identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token") | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, identity) | |
| 		assert.Equal(t, "user123", identity.UserID) | |
| 		assert.Equal(t, "user@example.com", identity.Email) | |
| 		assert.Equal(t, "Test User", identity.DisplayName) | |
| 		assert.Contains(t, identity.Groups, "users") | |
| 		assert.Contains(t, identity.Groups, "developers") | |
| 		assert.Equal(t, "test-oidc", identity.Provider) | |
| 	}) | |
| 
 | |
| 	t.Run("get admin user info", func(t *testing.T) { | |
| 		// Test admin token response | |
| 		identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token") | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, identity) | |
| 		assert.Equal(t, "user123", identity.UserID) | |
| 		assert.Contains(t, identity.Groups, "admins") | |
| 	}) | |
| 
 | |
| 	t.Run("get user info without token", func(t *testing.T) { | |
| 		// Test without access token (should fail) | |
| 		_, err := provider.GetUserInfoWithToken(context.Background(), "") | |
| 		assert.Error(t, err) | |
| 		assert.Contains(t, err.Error(), "access token cannot be empty") | |
| 	}) | |
| 
 | |
| 	t.Run("get user info with invalid token", func(t *testing.T) { | |
| 		// Test with invalid access token (should get 401) | |
| 		_, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token") | |
| 		assert.Error(t, err) | |
| 		assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401") | |
| 	}) | |
| 
 | |
| 	t.Run("get user info with custom claims mapping", func(t *testing.T) { | |
| 		// Create provider with custom claims mapping | |
| 		customProvider := NewOIDCProvider("test-custom-oidc") | |
| 		customConfig := &OIDCConfig{ | |
| 			Issuer:      server.URL, | |
| 			ClientID:    "test-client", | |
| 			UserInfoUri: server.URL + "/userinfo", | |
| 			ClaimsMapping: map[string]string{ | |
| 				"customEmail": "email", | |
| 				"customName":  "name", | |
| 			}, | |
| 		} | |
| 
 | |
| 		err := customProvider.Initialize(customConfig) | |
| 		require.NoError(t, err) | |
| 
 | |
| 		identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token") | |
| 		require.NoError(t, err) | |
| 		require.NotNil(t, identity) | |
| 
 | |
| 		// Standard claims should still work | |
| 		assert.Equal(t, "user123", identity.UserID) | |
| 		assert.Equal(t, "user@example.com", identity.Email) | |
| 		assert.Equal(t, "Test User", identity.DisplayName) | |
| 	}) | |
| 
 | |
| 	t.Run("get user info with empty id", func(t *testing.T) { | |
| 		_, err := provider.GetUserInfo(context.Background(), "") | |
| 		assert.Error(t, err) | |
| 	}) | |
| } | |
| 
 | |
| // Helper functions for testing | |
|  | |
| func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { | |
| 	privateKey, err := rsa.GenerateKey(rand.Reader, 2048) | |
| 	require.NoError(t, err) | |
| 	return privateKey, &privateKey.PublicKey | |
| } | |
| 
 | |
| func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string { | |
| 	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) | |
| 	token.Header["kid"] = "test-key-id" | |
| 
 | |
| 	tokenString, err := token.SignedString(privateKey) | |
| 	require.NoError(t, err) | |
| 	return tokenString | |
| } | |
| 
 | |
// encodePublicKey returns the modulus N of publicKey as unpadded base64url
// text over its big-endian bytes; the JWKS documents in this file use the
// result as the "n" member. The t parameter is not used.
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
	modulus := publicKey.N.Bytes()
	return base64.RawURLEncoding.EncodeToString(modulus)
}
| 
 | |
| func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server { | |
| 	jwks := map[string]interface{}{ | |
| 		"keys": []map[string]interface{}{ | |
| 			{ | |
| 				"kty": "RSA", | |
| 				"kid": "test-key-id", | |
| 				"use": "sig", | |
| 				"alg": "RS256", | |
| 				"n":   encodePublicKey(t, publicKey), | |
| 				"e":   "AQAB", | |
| 			}, | |
| 		}, | |
| 	} | |
| 
 | |
| 	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| 		switch r.URL.Path { | |
| 		case "/.well-known/openid_configuration": | |
| 			config := map[string]interface{}{ | |
| 				"issuer":            "http://" + r.Host, | |
| 				"jwks_uri":          "http://" + r.Host + "/jwks", | |
| 				"userinfo_endpoint": "http://" + r.Host + "/userinfo", | |
| 			} | |
| 			json.NewEncoder(w).Encode(config) | |
| 		case "/jwks": | |
| 			json.NewEncoder(w).Encode(jwks) | |
| 		case "/userinfo": | |
| 			// Mock UserInfo endpoint | |
| 			authHeader := r.Header.Get("Authorization") | |
| 			if !strings.HasPrefix(authHeader, "Bearer ") { | |
| 				w.WriteHeader(http.StatusUnauthorized) | |
| 				w.Write([]byte(`{"error": "unauthorized"}`)) | |
| 				return | |
| 			} | |
| 
 | |
| 			accessToken := strings.TrimPrefix(authHeader, "Bearer ") | |
| 
 | |
| 			// Return 401 for explicitly invalid tokens | |
| 			if accessToken == "invalid-token" { | |
| 				w.WriteHeader(http.StatusUnauthorized) | |
| 				w.Write([]byte(`{"error": "invalid_token"}`)) | |
| 				return | |
| 			} | |
| 
 | |
| 			// Mock user info response based on access token | |
| 			userInfo := map[string]interface{}{ | |
| 				"sub":    "user123", | |
| 				"email":  "user@example.com", | |
| 				"name":   "Test User", | |
| 				"groups": []string{"users", "developers"}, | |
| 			} | |
| 
 | |
| 			// Customize response based on token | |
| 			if strings.Contains(accessToken, "admin") { | |
| 				userInfo["groups"] = []string{"admins"} | |
| 			} | |
| 
 | |
| 			w.Header().Set("Content-Type", "application/json") | |
| 			json.NewEncoder(w).Encode(userInfo) | |
| 		default: | |
| 			http.NotFound(w, r) | |
| 		} | |
| 	})) | |
| }
 |