Browse Source

Merge e1bea6495d into 78a3441b30

pull/8311/merge
YGoetschel 1 day ago
committed by GitHub
parent
commit
555ff116a3
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 106
      weed/iam/oidc/oidc_provider.go
  2. 297
      weed/iam/oidc/oidc_race_test.go

106
weed/iam/oidc/oidc_provider.go

@ -16,11 +16,13 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"golang.org/x/sync/singleflight"
)
// OIDCProvider implements OpenID Connect authentication
@ -32,6 +34,8 @@ type OIDCProvider struct {
httpClient *http.Client
jwksFetchedAt time.Time
jwksTTL time.Duration
jwksMutex sync.RWMutex // Protects jwksCache and jwksFetchedAt
jwksFetchGroup singleflight.Group // Prevents duplicate concurrent JWKS fetches
}
// OIDCConfig holds OIDC provider configuration
@ -551,45 +555,110 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims)
}
// getPublicKey retrieves the public key for the given key ID from JWKS
// Uses singleflight pattern to prevent duplicate concurrent JWKS fetches
// and proper locking to avoid use-after-unlock bugs
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) {
// Fetch JWKS if not cached or refresh if expired
if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) {
if err := p.fetchJWKS(ctx); err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %v", err)
// Fast path: Try to find key in cache with read lock
p.jwksMutex.RLock()
cacheValid := p.jwksCache != nil && !p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) <= p.jwksTTL
if cacheValid {
// Make a copy of the JWK to avoid use-after-unlock
jwkCopy := p.findAndCopyKey(kid)
p.jwksMutex.RUnlock()
if jwkCopy != nil {
// Parse key outside lock for better concurrency
return p.parseJWK(jwkCopy)
}
// Key not found in valid cache - force refresh for potential key rotation
// Use distinct singleflight key to ensure refresh happens regardless of TTL
// This handles mid-TTL key rotation where IdP adds new keys
_, err, _ := p.jwksFetchGroup.Do("jwks-force-refresh", func() (interface{}, error) {
return nil, p.fetchJWKS(ctx)
})
if err != nil {
return nil, fmt.Errorf("failed to refresh JWKS for missing kid: %v", err)
}
// Search in refreshed JWKS
p.jwksMutex.RLock()
jwkCopy = p.findAndCopyKey(kid)
p.jwksMutex.RUnlock()
if jwkCopy != nil {
return p.parseJWK(jwkCopy)
}
return nil, fmt.Errorf("key with ID %s not found in JWKS after forced refresh", kid)
} else {
p.jwksMutex.RUnlock()
}
// Find the key with matching kid
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
return p.parseJWK(&key)
// Slow path: Cache expired or nil - need to fetch/refresh JWKS
// Use singleflight to ensure only one fetch happens even with many concurrent requests
_, err, _ := p.jwksFetchGroup.Do("jwks", func() (interface{}, error) {
// Double-check: another goroutine may have just fetched
p.jwksMutex.RLock()
stillNeedFetch := p.jwksCache == nil || p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) > p.jwksTTL
p.jwksMutex.RUnlock()
if stillNeedFetch {
// Fetch JWKS WITHOUT holding any locks (critical for performance)
return nil, p.fetchJWKS(ctx)
}
return nil, nil
})
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %v", err)
}
// Key not found in cache. Refresh JWKS once to handle key rotation and retry.
if err := p.fetchJWKS(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err)
// Search in newly fetched JWKS
p.jwksMutex.RLock()
jwkCopy := p.findAndCopyKey(kid)
p.jwksMutex.RUnlock()
if jwkCopy != nil {
return p.parseJWK(jwkCopy)
}
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
return p.parseJWK(&key)
// Key not found even after refresh
return nil, fmt.Errorf("key with ID %s not found in JWKS", kid)
}
// findAndCopyKey searches for a key and returns a copy (not a pointer)
// Must be called with at least a read lock held
func (p *OIDCProvider) findAndCopyKey(kid string) *JWK {
if p.jwksCache == nil {
return nil
}
for i := range p.jwksCache.Keys {
if p.jwksCache.Keys[i].Kid == kid {
// Return a copy to avoid use-after-unlock bugs
keyCopy := p.jwksCache.Keys[i]
return &keyCopy
}
}
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid)
return nil
}
// fetchJWKS fetches the JWKS from the provider
// fetchJWKS fetches the JWKS from the provider WITHOUT holding locks during HTTP call
// This is critical for performance - HTTP calls can take seconds
func (p *OIDCProvider) fetchJWKS(ctx context.Context) error {
jwksURL := p.config.JWKSUri
if jwksURL == "" {
jwksURL = strings.TrimSuffix(p.config.Issuer, "/") + "/.well-known/jwks.json"
}
// Create request WITHOUT holding any locks
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return fmt.Errorf("failed to create JWKS request: %v", err)
}
// Make HTTP call WITHOUT holding any locks (critical for performance)
resp, err := p.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to fetch JWKS: %v", err)
@ -600,13 +669,18 @@ func (p *OIDCProvider) fetchJWKS(ctx context.Context) error {
return fmt.Errorf("JWKS endpoint returned status: %d", resp.StatusCode)
}
// Decode response WITHOUT holding any locks
var jwks JWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return fmt.Errorf("failed to decode JWKS response: %v", err)
}
// Only acquire write lock for the actual cache update (very fast)
p.jwksMutex.Lock()
p.jwksCache = &jwks
p.jwksFetchedAt = time.Now()
p.jwksMutex.Unlock()
glog.V(3).Infof("Fetched JWKS with %d keys from %s", len(jwks.Keys), jwksURL)
return nil
}

297
weed/iam/oidc/oidc_race_test.go

@ -0,0 +1,297 @@
package oidc
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// TestJWKSCacheConcurrentRefresh tests that concurrent JWKS refresh doesn't cause race conditions
func TestJWKSCacheConcurrentRefresh(t *testing.T) {
// Generate RSA key pair for testing
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Track JWKS fetch count
var fetchCount int
var fetchMutex sync.Mutex
// Create mock JWKS server
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchMutex.Lock()
fetchCount++
currentFetch := fetchCount
fetchMutex.Unlock()
t.Logf("JWKS fetch #%d", currentFetch)
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-1",
"use": "sig",
"alg": "RS256",
"n": encodePublicKey(t, &privateKey.PublicKey),
"e": "AQAB",
},
},
}
json.NewEncoder(w).Encode(jwks)
}))
defer jwksServer.Close()
// Initialize OIDC provider with very short TTL to force refresh
provider := NewOIDCProvider("test")
config := &OIDCConfig{
Issuer: "https://test.example.com",
ClientID: "test-client-id",
JWKSUri: jwksServer.URL,
JWKSCacheTTLSeconds: 1, // Very short TTL
}
if err := provider.Initialize(config); err != nil {
t.Fatalf("Failed to initialize provider: %v", err)
}
// Generate valid JWT token
token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id")
// Test 1: Concurrent validation with initial JWKS fetch
t.Run("concurrent_initial_fetch", func(t *testing.T) {
fetchCount = 0
provider.jwksCache = nil // Reset cache
var wg sync.WaitGroup
concurrency := 50
successCount := 0
var successMutex sync.Mutex
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
ctx := context.Background()
_, err := provider.ValidateToken(ctx, token)
if err == nil {
successMutex.Lock()
successCount++
successMutex.Unlock()
} else {
t.Logf("Goroutine %d validation error: %v", id, err)
}
}(i)
}
wg.Wait()
// All validations should succeed
if successCount != concurrency {
t.Errorf("Expected %d successful validations, got %d", concurrency, successCount)
}
// JWKS should be fetched only once or very few times (due to double-checked locking)
t.Logf("JWKS fetched %d times for %d concurrent requests", fetchCount, concurrency)
if fetchCount > 5 {
t.Errorf("Too many JWKS fetches: %d (expected <= 5 due to locking)", fetchCount)
}
})
// Test 2: Concurrent validation during cache expiration
t.Run("concurrent_cache_expiration", func(t *testing.T) {
fetchCount = 0
// Pre-populate cache
ctx := context.Background()
_, err := provider.ValidateToken(ctx, token)
if err != nil {
t.Fatalf("Initial validation failed: %v", err)
}
initialFetchCount := fetchCount
t.Logf("Initial fetch count: %d", initialFetchCount)
// Wait for cache to expire
time.Sleep(1100 * time.Millisecond)
// Concurrent validations after cache expiration
var wg sync.WaitGroup
concurrency := 50
successCount := 0
var successMutex sync.Mutex
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
_, err := provider.ValidateToken(ctx, token)
if err == nil {
successMutex.Lock()
successCount++
successMutex.Unlock()
} else {
t.Logf("Goroutine %d validation error: %v", id, err)
}
}(i)
}
wg.Wait()
// All validations should succeed
if successCount != concurrency {
t.Errorf("Expected %d successful validations, got %d", concurrency, successCount)
}
// Should only fetch once more after expiration
refreshFetchCount := fetchCount - initialFetchCount
t.Logf("JWKS refreshed %d times for %d concurrent requests after expiration", refreshFetchCount, concurrency)
if refreshFetchCount > 5 {
t.Errorf("Too many JWKS refreshes: %d (expected <= 5)", refreshFetchCount)
}
})
// Test 3: Race detector test - rapid concurrent access
t.Run("race_detector", func(t *testing.T) {
// Reset cache
provider.jwksCache = nil
fetchCount = 0
var wg sync.WaitGroup
concurrency := 100
iterations := 10
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
for j := 0; j < iterations; j++ {
provider.ValidateToken(ctx, token)
// Small random delay
time.Sleep(time.Millisecond * time.Duration(j%3))
}
}()
}
wg.Wait()
t.Logf("Completed %d iterations across %d goroutines without race", iterations, concurrency)
})
}
// TestJWKSCacheIsolation tests that cache updates don't interfere with ongoing reads
func TestJWKSCacheIsolation(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Create JWKS server that alternates between two key sets
keyVersion := 0
var keyMutex sync.Mutex
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keyMutex.Lock()
currentVersion := keyVersion
keyVersion++
keyMutex.Unlock()
// Simulate slow response
time.Sleep(50 * time.Millisecond)
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-1",
"use": "sig",
"alg": "RS256",
"n": encodePublicKey(t, &privateKey.PublicKey),
"e": "AQAB",
},
},
}
t.Logf("Serving JWKS version %d", currentVersion)
json.NewEncoder(w).Encode(jwks)
}))
defer jwksServer.Close()
provider := NewOIDCProvider("test")
config := &OIDCConfig{
Issuer: "https://test.example.com",
ClientID: "test-client-id",
JWKSUri: jwksServer.URL,
JWKSCacheTTLSeconds: 1,
}
if err := provider.Initialize(config); err != nil {
t.Fatalf("Failed to initialize provider: %v", err)
}
token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id")
// Concurrent readers and writers
var wg sync.WaitGroup
ctx := context.Background()
errorCount := 0
var errorMutex sync.Mutex
// Readers
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 20; j++ {
_, err := provider.ValidateToken(ctx, token)
if err != nil {
errorMutex.Lock()
errorCount++
errorMutex.Unlock()
t.Logf("Reader %d got error: %v", id, err)
}
time.Sleep(5 * time.Millisecond)
}
}(i)
}
wg.Wait()
if errorCount > 0 {
t.Errorf("Got %d errors during concurrent read/write operations", errorCount)
}
}
// Helper function to generate test JWT token
func generateTestToken(t *testing.T, privateKey *rsa.PrivateKey, kid, issuer, audience string) string {
claims := jwt.MapClaims{
"sub": "test-user",
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"email": "test@example.com",
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(privateKey)
if err != nil {
t.Fatalf("Failed to sign token: %v", err)
}
return tokenString
}
Loading…
Cancel
Save