committed by
GitHub
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 387 additions and 16 deletions
@ -0,0 +1,297 @@ |
|||||
|
package oidc |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"sync" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/golang-jwt/jwt/v5" |
||||
|
) |
||||
|
|
||||
|
// TestJWKSCacheConcurrentRefresh tests that concurrent JWKS refresh doesn't cause race conditions
|
||||
|
func TestJWKSCacheConcurrentRefresh(t *testing.T) { |
||||
|
// Generate RSA key pair for testing
|
||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Failed to generate RSA key: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Track JWKS fetch count
|
||||
|
var fetchCount int |
||||
|
var fetchMutex sync.Mutex |
||||
|
|
||||
|
// Create mock JWKS server
|
||||
|
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
fetchMutex.Lock() |
||||
|
fetchCount++ |
||||
|
currentFetch := fetchCount |
||||
|
fetchMutex.Unlock() |
||||
|
|
||||
|
t.Logf("JWKS fetch #%d", currentFetch) |
||||
|
|
||||
|
// Add small delay to simulate network latency
|
||||
|
time.Sleep(10 * time.Millisecond) |
||||
|
|
||||
|
jwks := map[string]interface{}{ |
||||
|
"keys": []map[string]interface{}{ |
||||
|
{ |
||||
|
"kty": "RSA", |
||||
|
"kid": "test-key-1", |
||||
|
"use": "sig", |
||||
|
"alg": "RS256", |
||||
|
"n": encodePublicKey(t, &privateKey.PublicKey), |
||||
|
"e": "AQAB", |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
json.NewEncoder(w).Encode(jwks) |
||||
|
})) |
||||
|
defer jwksServer.Close() |
||||
|
|
||||
|
// Initialize OIDC provider with very short TTL to force refresh
|
||||
|
provider := NewOIDCProvider("test") |
||||
|
config := &OIDCConfig{ |
||||
|
Issuer: "https://test.example.com", |
||||
|
ClientID: "test-client-id", |
||||
|
JWKSUri: jwksServer.URL, |
||||
|
JWKSCacheTTLSeconds: 1, // Very short TTL
|
||||
|
} |
||||
|
|
||||
|
if err := provider.Initialize(config); err != nil { |
||||
|
t.Fatalf("Failed to initialize provider: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Generate valid JWT token
|
||||
|
token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id") |
||||
|
|
||||
|
// Test 1: Concurrent validation with initial JWKS fetch
|
||||
|
t.Run("concurrent_initial_fetch", func(t *testing.T) { |
||||
|
fetchCount = 0 |
||||
|
provider.jwksCache = nil // Reset cache
|
||||
|
|
||||
|
var wg sync.WaitGroup |
||||
|
concurrency := 50 |
||||
|
successCount := 0 |
||||
|
var successMutex sync.Mutex |
||||
|
|
||||
|
for i := 0; i < concurrency; i++ { |
||||
|
wg.Add(1) |
||||
|
go func(id int) { |
||||
|
defer wg.Done() |
||||
|
ctx := context.Background() |
||||
|
_, err := provider.ValidateToken(ctx, token) |
||||
|
if err == nil { |
||||
|
successMutex.Lock() |
||||
|
successCount++ |
||||
|
successMutex.Unlock() |
||||
|
} else { |
||||
|
t.Logf("Goroutine %d validation error: %v", id, err) |
||||
|
} |
||||
|
}(i) |
||||
|
} |
||||
|
|
||||
|
wg.Wait() |
||||
|
|
||||
|
// All validations should succeed
|
||||
|
if successCount != concurrency { |
||||
|
t.Errorf("Expected %d successful validations, got %d", concurrency, successCount) |
||||
|
} |
||||
|
|
||||
|
// JWKS should be fetched only once or very few times (due to double-checked locking)
|
||||
|
t.Logf("JWKS fetched %d times for %d concurrent requests", fetchCount, concurrency) |
||||
|
if fetchCount > 5 { |
||||
|
t.Errorf("Too many JWKS fetches: %d (expected <= 5 due to locking)", fetchCount) |
||||
|
} |
||||
|
}) |
||||
|
|
||||
|
// Test 2: Concurrent validation during cache expiration
|
||||
|
t.Run("concurrent_cache_expiration", func(t *testing.T) { |
||||
|
fetchCount = 0 |
||||
|
|
||||
|
// Pre-populate cache
|
||||
|
ctx := context.Background() |
||||
|
_, err := provider.ValidateToken(ctx, token) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Initial validation failed: %v", err) |
||||
|
} |
||||
|
|
||||
|
initialFetchCount := fetchCount |
||||
|
t.Logf("Initial fetch count: %d", initialFetchCount) |
||||
|
|
||||
|
// Wait for cache to expire
|
||||
|
time.Sleep(1100 * time.Millisecond) |
||||
|
|
||||
|
// Concurrent validations after cache expiration
|
||||
|
var wg sync.WaitGroup |
||||
|
concurrency := 50 |
||||
|
successCount := 0 |
||||
|
var successMutex sync.Mutex |
||||
|
|
||||
|
for i := 0; i < concurrency; i++ { |
||||
|
wg.Add(1) |
||||
|
go func(id int) { |
||||
|
defer wg.Done() |
||||
|
_, err := provider.ValidateToken(ctx, token) |
||||
|
if err == nil { |
||||
|
successMutex.Lock() |
||||
|
successCount++ |
||||
|
successMutex.Unlock() |
||||
|
} else { |
||||
|
t.Logf("Goroutine %d validation error: %v", id, err) |
||||
|
} |
||||
|
}(i) |
||||
|
} |
||||
|
|
||||
|
wg.Wait() |
||||
|
|
||||
|
// All validations should succeed
|
||||
|
if successCount != concurrency { |
||||
|
t.Errorf("Expected %d successful validations, got %d", concurrency, successCount) |
||||
|
} |
||||
|
|
||||
|
// Should only fetch once more after expiration
|
||||
|
refreshFetchCount := fetchCount - initialFetchCount |
||||
|
t.Logf("JWKS refreshed %d times for %d concurrent requests after expiration", refreshFetchCount, concurrency) |
||||
|
if refreshFetchCount > 5 { |
||||
|
t.Errorf("Too many JWKS refreshes: %d (expected <= 5)", refreshFetchCount) |
||||
|
} |
||||
|
}) |
||||
|
|
||||
|
// Test 3: Race detector test - rapid concurrent access
|
||||
|
t.Run("race_detector", func(t *testing.T) { |
||||
|
// Reset cache
|
||||
|
provider.jwksCache = nil |
||||
|
fetchCount = 0 |
||||
|
|
||||
|
var wg sync.WaitGroup |
||||
|
concurrency := 100 |
||||
|
iterations := 10 |
||||
|
|
||||
|
for i := 0; i < concurrency; i++ { |
||||
|
wg.Add(1) |
||||
|
go func() { |
||||
|
defer wg.Done() |
||||
|
ctx := context.Background() |
||||
|
for j := 0; j < iterations; j++ { |
||||
|
provider.ValidateToken(ctx, token) |
||||
|
// Small random delay
|
||||
|
time.Sleep(time.Millisecond * time.Duration(j%3)) |
||||
|
} |
||||
|
}() |
||||
|
} |
||||
|
|
||||
|
wg.Wait() |
||||
|
t.Logf("Completed %d iterations across %d goroutines without race", iterations, concurrency) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestJWKSCacheIsolation tests that cache updates don't interfere with ongoing reads
|
||||
|
func TestJWKSCacheIsolation(t *testing.T) { |
||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Failed to generate RSA key: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Create JWKS server that alternates between two key sets
|
||||
|
keyVersion := 0 |
||||
|
var keyMutex sync.Mutex |
||||
|
|
||||
|
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
keyMutex.Lock() |
||||
|
currentVersion := keyVersion |
||||
|
keyVersion++ |
||||
|
keyMutex.Unlock() |
||||
|
|
||||
|
// Simulate slow response
|
||||
|
time.Sleep(50 * time.Millisecond) |
||||
|
|
||||
|
jwks := map[string]interface{}{ |
||||
|
"keys": []map[string]interface{}{ |
||||
|
{ |
||||
|
"kty": "RSA", |
||||
|
"kid": "test-key-1", |
||||
|
"use": "sig", |
||||
|
"alg": "RS256", |
||||
|
"n": encodePublicKey(t, &privateKey.PublicKey), |
||||
|
"e": "AQAB", |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
t.Logf("Serving JWKS version %d", currentVersion) |
||||
|
json.NewEncoder(w).Encode(jwks) |
||||
|
})) |
||||
|
defer jwksServer.Close() |
||||
|
|
||||
|
provider := NewOIDCProvider("test") |
||||
|
config := &OIDCConfig{ |
||||
|
Issuer: "https://test.example.com", |
||||
|
ClientID: "test-client-id", |
||||
|
JWKSUri: jwksServer.URL, |
||||
|
JWKSCacheTTLSeconds: 1, |
||||
|
} |
||||
|
|
||||
|
if err := provider.Initialize(config); err != nil { |
||||
|
t.Fatalf("Failed to initialize provider: %v", err) |
||||
|
} |
||||
|
|
||||
|
token := generateTestToken(t, privateKey, "test-key-1", "https://test.example.com", "test-client-id") |
||||
|
|
||||
|
// Concurrent readers and writers
|
||||
|
var wg sync.WaitGroup |
||||
|
ctx := context.Background() |
||||
|
errorCount := 0 |
||||
|
var errorMutex sync.Mutex |
||||
|
|
||||
|
// Readers
|
||||
|
for i := 0; i < 10; i++ { |
||||
|
wg.Add(1) |
||||
|
go func(id int) { |
||||
|
defer wg.Done() |
||||
|
for j := 0; j < 20; j++ { |
||||
|
_, err := provider.ValidateToken(ctx, token) |
||||
|
if err != nil { |
||||
|
errorMutex.Lock() |
||||
|
errorCount++ |
||||
|
errorMutex.Unlock() |
||||
|
t.Logf("Reader %d got error: %v", id, err) |
||||
|
} |
||||
|
time.Sleep(5 * time.Millisecond) |
||||
|
} |
||||
|
}(i) |
||||
|
} |
||||
|
|
||||
|
wg.Wait() |
||||
|
|
||||
|
if errorCount > 0 { |
||||
|
t.Errorf("Got %d errors during concurrent read/write operations", errorCount) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Helper function to generate test JWT token
|
||||
|
func generateTestToken(t *testing.T, privateKey *rsa.PrivateKey, kid, issuer, audience string) string { |
||||
|
claims := jwt.MapClaims{ |
||||
|
"sub": "test-user", |
||||
|
"iss": issuer, |
||||
|
"aud": audience, |
||||
|
"exp": time.Now().Add(1 * time.Hour).Unix(), |
||||
|
"iat": time.Now().Unix(), |
||||
|
"email": "test@example.com", |
||||
|
} |
||||
|
|
||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) |
||||
|
token.Header["kid"] = kid |
||||
|
|
||||
|
tokenString, err := token.SignedString(privateKey) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Failed to sign token: %v", err) |
||||
|
} |
||||
|
|
||||
|
return tokenString |
||||
|
} |
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue