diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index dda970665..70e69e3fd 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -16,11 +16,13 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/iam/providers" + "golang.org/x/sync/singleflight" ) // OIDCProvider implements OpenID Connect authentication @@ -32,6 +34,8 @@ type OIDCProvider struct { httpClient *http.Client jwksFetchedAt time.Time jwksTTL time.Duration + jwksMutex sync.RWMutex // Protects jwksCache and jwksFetchedAt + jwksFetchGroup singleflight.Group // Prevents duplicate concurrent JWKS fetches } // OIDCConfig holds OIDC provider configuration @@ -551,45 +555,110 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) } // getPublicKey retrieves the public key for the given key ID from JWKS +// Uses singleflight pattern to prevent duplicate concurrent JWKS fetches +// and proper locking to avoid use-after-unlock bugs func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { - // Fetch JWKS if not cached or refresh if expired - if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) { - if err := p.fetchJWKS(ctx); err != nil { - return nil, fmt.Errorf("failed to fetch JWKS: %v", err) + // Fast path: Try to find key in cache with read lock + p.jwksMutex.RLock() + cacheValid := p.jwksCache != nil && !p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) <= p.jwksTTL + if cacheValid { + // Make a copy of the JWK to avoid use-after-unlock + jwkCopy := p.findAndCopyKey(kid) + p.jwksMutex.RUnlock() + + if jwkCopy != nil { + // Parse key outside lock for better concurrency + return p.parseJWK(jwkCopy) } + + // Key not found in valid cache - force refresh for potential key rotation + // Use distinct singleflight key to ensure refresh happens regardless of TTL + // This handles mid-TTL key rotation where IdP adds new keys + _, err, _ := p.jwksFetchGroup.Do("jwks-force-refresh", func() (interface{}, error) { + return nil, p.fetchJWKS(ctx) + }) + + if err != nil { + return nil, fmt.Errorf("failed to refresh JWKS for missing kid: %v", err) + } + + // Search in refreshed JWKS + p.jwksMutex.RLock() + jwkCopy = p.findAndCopyKey(kid) + p.jwksMutex.RUnlock() + + if jwkCopy != nil { + return p.parseJWK(jwkCopy) + } + + return nil, fmt.Errorf("key with ID %s not found in JWKS after forced refresh", kid) + } else { + p.jwksMutex.RUnlock() } - // Find the key with matching kid - for _, key := range p.jwksCache.Keys { - if key.Kid == kid { - return p.parseJWK(&key) + // Slow path: Cache expired or nil - need to fetch/refresh JWKS + // Use singleflight to ensure only one fetch happens even with many concurrent requests + _, err, _ := p.jwksFetchGroup.Do("jwks", func() (interface{}, error) { + // Double-check: another goroutine may have just fetched + p.jwksMutex.RLock() + stillNeedFetch := p.jwksCache == nil || p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) > p.jwksTTL + p.jwksMutex.RUnlock() + + if stillNeedFetch { + // Fetch JWKS WITHOUT holding any locks (critical for performance) + return nil, p.fetchJWKS(ctx) } + return nil, nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %v", err) } - // Key not found in cache. Refresh JWKS once to handle key rotation and retry. - if err := p.fetchJWKS(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err) + // Search in newly fetched JWKS + p.jwksMutex.RLock() + jwkCopy := p.findAndCopyKey(kid) + p.jwksMutex.RUnlock() + + if jwkCopy != nil { + return p.parseJWK(jwkCopy) } - for _, key := range p.jwksCache.Keys { - if key.Kid == kid { - return p.parseJWK(&key) + + // Key not found even after refresh + return nil, fmt.Errorf("key with ID %s not found in JWKS", kid) +} + +// findAndCopyKey searches for a key and returns a copy (not a pointer) +// Must be called with at least a read lock held +func (p *OIDCProvider) findAndCopyKey(kid string) *JWK { + if p.jwksCache == nil { + return nil + } + for i := range p.jwksCache.Keys { + if p.jwksCache.Keys[i].Kid == kid { + // Return a copy to avoid use-after-unlock bugs + keyCopy := p.jwksCache.Keys[i] + return &keyCopy } } - return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) + return nil } -// fetchJWKS fetches the JWKS from the provider +// fetchJWKS fetches the JWKS from the provider WITHOUT holding locks during HTTP call +// This is critical for performance - HTTP calls can take seconds func (p *OIDCProvider) fetchJWKS(ctx context.Context) error { jwksURL := p.config.JWKSUri if jwksURL == "" { jwksURL = strings.TrimSuffix(p.config.Issuer, "/") + "/.well-known/jwks.json" } + // Create request WITHOUT holding any locks req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) if err != nil { return fmt.Errorf("failed to create JWKS request: %v", err) } + // Make HTTP call WITHOUT holding any locks (critical for performance) resp, err := p.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to fetch JWKS: %v", err) @@ -600,13 +669,18 @@ func (p *OIDCProvider) fetchJWKS(ctx context.Context) error { return fmt.Errorf("JWKS endpoint returned status: %d", resp.StatusCode) } + // Decode response WITHOUT holding any locks var jwks JWKS if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { return fmt.Errorf("failed to decode JWKS response: %v", err) } + // Only acquire write lock for the actual cache update (very fast) + p.jwksMutex.Lock() p.jwksCache = &jwks p.jwksFetchedAt = time.Now() + p.jwksMutex.Unlock() + glog.V(3).Infof("Fetched JWKS with %d keys from %s", len(jwks.Keys), jwksURL) return nil } diff --git a/weed/iam/oidc/oidc_race_test.go b/weed/iam/oidc/oidc_race_test.go new file mode 100644 index 000000000..d6a87e912 --- /dev/null +++ b/weed/iam/oidc/oidc_race_test.go @@ -0,0 +1,297 @@ +package oidc + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// TestJWKSCacheConcurrentRefresh tests that concurrent JWKS refresh doesn't cause race conditions +func TestJWKSCacheConcurrentRefresh(t *testing.T) { + // Generate RSA key pair for testing + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Track JWKS fetch count + var fetchCount int + var fetchMutex sync.Mutex + + // Create mock JWKS server + jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchMutex.Lock() + fetchCount++ + currentFetch := fetchCount + fetchMutex.Unlock() + + t.Logf("JWKS fetch #%d", currentFetch) + + // Add small delay to simulate network latency + time.Sleep(10 * time.Millisecond) + + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": encodePublicKey(t, &privateKey.PublicKey), + "e": "AQAB", + }, + }, + } + json.NewEncoder(w).Encode(jwks) + })) + defer jwksServer.Close() + + // Initialize OIDC provider with very short TTL to force refresh + provider := NewOIDCProvider("test") + config := &OIDCConfig{ + Issuer: "https://test.example.com", + ClientID: "test-client-id", + JWKSUri: jwksServer.URL, + JWKSCacheTTLSeconds: 1, // Very short TTL + } + + if err := provider.Initialize(config); err != nil { + t.Fatalf("Failed to initialize provider: %v", err) + } + + // Generate valid JWT token + token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id") + + // Test 1: Concurrent validation with initial JWKS fetch + t.Run("concurrent_initial_fetch", func(t *testing.T) { + fetchCount = 0 + provider.jwksCache = nil // Reset cache + + var wg sync.WaitGroup + concurrency := 50 + successCount := 0 + var successMutex sync.Mutex + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + ctx := context.Background() + _, err := provider.ValidateToken(ctx, token) + if err == nil { + successMutex.Lock() + successCount++ + successMutex.Unlock() + } else { + t.Logf("Goroutine %d validation error: %v", id, err) + } + }(i) + } + + wg.Wait() + + // All validations should succeed + if successCount != concurrency { + t.Errorf("Expected %d successful validations, got %d", concurrency, successCount) + } + + // JWKS should be fetched only once or very few times (due to double-checked locking) + t.Logf("JWKS fetched %d times for %d concurrent requests", fetchCount, concurrency) + if fetchCount > 5 { + t.Errorf("Too many JWKS fetches: %d (expected <= 5 due to locking)", fetchCount) + } + }) + + // Test 2: Concurrent validation during cache expiration + t.Run("concurrent_cache_expiration", func(t *testing.T) { + fetchCount = 0 + + // Pre-populate cache + ctx := context.Background() + _, err := provider.ValidateToken(ctx, token) + if err != nil { + t.Fatalf("Initial validation failed: %v", err) + } + + initialFetchCount := fetchCount + t.Logf("Initial fetch count: %d", initialFetchCount) + + // Wait for cache to expire + time.Sleep(1100 * time.Millisecond) + + // Concurrent validations after cache expiration + var wg sync.WaitGroup + concurrency := 50 + successCount := 0 + var successMutex sync.Mutex + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + _, err := provider.ValidateToken(ctx, token) + if err == nil { + successMutex.Lock() + successCount++ + successMutex.Unlock() + } else { + t.Logf("Goroutine %d validation error: %v", id, err) + } + }(i) + } + + wg.Wait() + + // All validations should succeed + if successCount != concurrency { + t.Errorf("Expected %d successful validations, got %d", concurrency, successCount) + } + + // Should only fetch once more after expiration + refreshFetchCount := fetchCount - initialFetchCount + t.Logf("JWKS refreshed %d times for %d concurrent requests after expiration", refreshFetchCount, concurrency) + if refreshFetchCount > 5 { + t.Errorf("Too many JWKS refreshes: %d (expected <= 5)", refreshFetchCount) + } + }) + + // Test 3: Race detector test - rapid concurrent access + t.Run("race_detector", func(t *testing.T) { + // Reset cache + provider.jwksCache = nil + fetchCount = 0 + + var wg sync.WaitGroup + concurrency := 100 + iterations := 10 + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for j := 0; j < iterations; j++ { + provider.ValidateToken(ctx, token) + // Small random delay + time.Sleep(time.Millisecond * time.Duration(j%3)) + } + }() + } + + wg.Wait() + t.Logf("Completed %d iterations across %d goroutines without race", iterations, concurrency) + }) +} + +// TestJWKSCacheIsolation tests that cache updates don't interfere with ongoing reads +func TestJWKSCacheIsolation(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Create JWKS server that alternates between two key sets + keyVersion := 0 + var keyMutex sync.Mutex + + jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + keyMutex.Lock() + currentVersion := keyVersion + keyVersion++ + keyMutex.Unlock() + + // Simulate slow response + time.Sleep(50 * time.Millisecond) + + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": encodePublicKey(t, &privateKey.PublicKey), + "e": "AQAB", + }, + }, + } + + t.Logf("Serving JWKS version %d", currentVersion) + json.NewEncoder(w).Encode(jwks) + })) + defer jwksServer.Close() + + provider := NewOIDCProvider("test") + config := &OIDCConfig{ + Issuer: "https://test.example.com", + ClientID: "test-client-id", + JWKSUri: jwksServer.URL, + JWKSCacheTTLSeconds: 1, + } + + if err := provider.Initialize(config); err != nil { + t.Fatalf("Failed to initialize provider: %v", err) + } + + token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id") + + // Concurrent readers and writers + var wg sync.WaitGroup + ctx := context.Background() + errorCount := 0 + var errorMutex sync.Mutex + + // Readers + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 20; j++ { + _, err := provider.ValidateToken(ctx, token) + if err != nil { + errorMutex.Lock() + errorCount++ + errorMutex.Unlock() + t.Logf("Reader %d got error: %v", id, err) + } + time.Sleep(5 * time.Millisecond) + } + }(i) + } + + wg.Wait() + + if errorCount > 0 { + t.Errorf("Got %d errors during concurrent read/write operations", errorCount) + } +} + +// Helper function to generate test JWT token +func generateTestToken(t *testing.T, privateKey *rsa.PrivateKey, kid, issuer, audience string) string { + claims := jwt.MapClaims{ + "sub": "test-user", + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "email": "test@example.com", + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + + tokenString, err := token.SignedString(privateKey) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + return tokenString +}