diff --git a/test/s3/iam/DISTRIBUTED.md b/test/s3/iam/DISTRIBUTED.md index 16b356108..a0be7b108 100644 --- a/test/s3/iam/DISTRIBUTED.md +++ b/test/s3/iam/DISTRIBUTED.md @@ -45,7 +45,7 @@ All S3 gateway instances share the same IAM state through the filer. "sts": { "tokenDuration": 3600000000000, "maxSessionLength": 43200000000000, - "issuer": "seaweedfs-sts", + "issuer": "seaweedfs-sts", "signingKey": "base64-encoded-signing-key", "sessionStoreType": "filer", "sessionStoreConfig": { @@ -64,7 +64,7 @@ All S3 gateway instances share the same IAM state through the filer. "roleStore": { "storeType": "filer", "storeConfig": { - "filerAddress": "localhost:8888", + "filerAddress": "localhost:8888", "basePath": "/seaweedfs/iam/roles" } } diff --git a/test/s3/iam/s3_iam_distributed_test.go b/test/s3/iam/s3_iam_distributed_test.go index a5d9dbd5e..16d108bbf 100644 --- a/test/s3/iam/s3_iam_distributed_test.go +++ b/test/s3/iam/s3_iam_distributed_test.go @@ -128,38 +128,34 @@ func TestS3IAMDistributedTests(t *testing.T) { if err := framework.CreateBucket(client, bucketName); err != nil { errors <- err continue - + } // Small delay to reduce server load time.Sleep(100 * time.Millisecond) - } // Put object objectKey := "test-object.txt" if err := framework.PutTestObject(client, bucketName, objectKey, fmt.Sprintf("content-%d-%d", goroutineID, j)); err != nil { errors <- err continue - + } // Small delay to reduce server load time.Sleep(100 * time.Millisecond) - } // Get object if _, err := framework.GetTestObject(client, bucketName, objectKey); err != nil { errors <- err continue - + } // Small delay to reduce server load time.Sleep(100 * time.Millisecond) - } // Delete object if err := framework.DeleteTestObject(client, bucketName, objectKey); err != nil { errors <- err continue - + } // Small delay to reduce server load time.Sleep(100 * time.Millisecond) - } // Delete bucket if _, err := client.DeleteBucket(&s3.DeleteBucketInput{ @@ -167,10 +163,9 @@ func TestS3IAMDistributedTests(t *testing.T) { }); err != nil { errors <- err continue - + } // Small delay to reduce server load time.Sleep(100 * time.Millisecond) - } } }(i) } @@ -183,15 +178,15 @@ func TestS3IAMDistributedTests(t *testing.T) { for err := range errors { errorList = append(errorList, err) } - + totalOperations := numGoroutines * numOperationsPerGoroutine errorRate := float64(len(errorList)) / float64(totalOperations) - + if len(errorList) > 0 { - t.Logf("Concurrent operations: %d/%d operations failed (%.1f%% error rate). First error: %v", + t.Logf("Concurrent operations: %d/%d operations failed (%.1f%% error rate). First error: %v", len(errorList), totalOperations, errorRate*100, errorList[0]) } - + // Allow up to 50% error rate for concurrent stress testing // This tests that the system handles concurrent load gracefully if errorRate > 0.5 { diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index 3fbe6ce85..006e76e57 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -18,11 +18,13 @@ import ( // OIDCProvider implements OpenID Connect authentication type OIDCProvider struct { - name string - config *OIDCConfig - initialized bool - jwksCache *JWKS - httpClient *http.Client + name string + config *OIDCConfig + initialized bool + jwksCache *JWKS + httpClient *http.Client + jwksFetchedAt time.Time + jwksTTL time.Duration } // OIDCConfig holds OIDC provider configuration @@ -50,6 +52,9 @@ type OIDCConfig struct { // ClaimsMapping defines how to map OIDC claims to identity attributes ClaimsMapping map[string]string `json:"claimsMapping,omitempty"` + + // JWKSCacheTTLSeconds sets how long to cache JWKS before refresh (default 3600 seconds) + JWKSCacheTTLSeconds int `json:"jwksCacheTTLSeconds,omitempty"` } // JWKS represents JSON Web Key Set @@ -101,6 +106,13 @@ func (p *OIDCProvider) Initialize(config interface{}) error { p.config = oidcConfig p.initialized = true + // Configure JWKS cache TTL + if oidcConfig.JWKSCacheTTLSeconds > 0 { + p.jwksTTL = time.Duration(oidcConfig.JWKSCacheTTLSeconds) * time.Second + } else { + p.jwksTTL = time.Hour + } + // For testing, we'll skip the actual OIDC client initialization return nil } @@ -385,8 +397,8 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) // getPublicKey retrieves the public key for the given key ID from JWKS func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { - // Fetch JWKS if not cached or refresh if needed - if p.jwksCache == nil { + // Fetch JWKS if not cached or refresh if expired + if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) { if err := p.fetchJWKS(ctx); err != nil { return nil, fmt.Errorf("failed to fetch JWKS: %v", err) } @@ -399,7 +411,16 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{ } } - return nil, fmt.Errorf("key with ID %s not found in JWKS", kid) + // Key not found in cache. Refresh JWKS once to handle key rotation and retry. + if err := p.fetchJWKS(ctx); err != nil { + return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err) + } + for _, key := range p.jwksCache.Keys { + if key.Kid == kid { + return p.parseJWK(&key) + } + } + return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) } // fetchJWKS fetches the JWKS from the provider @@ -430,6 +451,7 @@ func (p *OIDCProvider) fetchJWKS(ctx context.Context) error { } p.jwksCache = &jwks + p.jwksFetchedAt = time.Now() glog.V(3).Infof("Fetched JWKS with %d keys from %s", len(jwks.Keys), jwksURL) return nil } diff --git a/weed/iam/providers/provider.go b/weed/iam/providers/provider.go index 2a3d9790f..74d63dd46 100644 --- a/weed/iam/providers/provider.go +++ b/weed/iam/providers/provider.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "net/mail" - "strings" + "path/filepath" "time" ) @@ -210,15 +210,9 @@ func (r *MappingRule) Matches(claims *TokenClaims) bool { // matchValue checks if a value matches the rule value (with wildcard support) func (r *MappingRule) matchValue(value string) bool { - // Simple wildcard matching - if strings.Contains(r.Value, "*") { - // Convert wildcard to regex-like matching - pattern := strings.ReplaceAll(r.Value, "*", "") - if pattern == "" { - return true // "*" matches everything - } - return strings.Contains(value, pattern) + matched, err := filepath.Match(r.Value, value) + if err != nil { + return false } - - return value == r.Value + return matched }