Browse Source

address comments

pull/7160/head
chrislu 1 month ago
parent
commit
02798df85d
  1. 4
      test/s3/iam/DISTRIBUTED.md
  2. 23
      test/s3/iam/s3_iam_distributed_test.go
  3. 38
      weed/iam/oidc/oidc_provider.go
  4. 16
      weed/iam/providers/provider.go

4
test/s3/iam/DISTRIBUTED.md

@ -45,7 +45,7 @@ All S3 gateway instances share the same IAM state through the filer.
"sts": {
"tokenDuration": 3600000000000,
"maxSessionLength": 43200000000000,
"issuer": "seaweedfs-sts",
"issuer": "seaweedfs-sts",
"signingKey": "base64-encoded-signing-key",
"sessionStoreType": "filer",
"sessionStoreConfig": {
@ -64,7 +64,7 @@ All S3 gateway instances share the same IAM state through the filer.
"roleStore": {
"storeType": "filer",
"storeConfig": {
"filerAddress": "localhost:8888",
"filerAddress": "localhost:8888",
"basePath": "/seaweedfs/iam/roles"
}
}

23
test/s3/iam/s3_iam_distributed_test.go

@ -128,38 +128,34 @@ func TestS3IAMDistributedTests(t *testing.T) {
if err := framework.CreateBucket(client, bucketName); err != nil {
errors <- err
continue
}
// Small delay to reduce server load
time.Sleep(100 * time.Millisecond)
}
// Put object
objectKey := "test-object.txt"
if err := framework.PutTestObject(client, bucketName, objectKey, fmt.Sprintf("content-%d-%d", goroutineID, j)); err != nil {
errors <- err
continue
}
// Small delay to reduce server load
time.Sleep(100 * time.Millisecond)
}
// Get object
if _, err := framework.GetTestObject(client, bucketName, objectKey); err != nil {
errors <- err
continue
}
// Small delay to reduce server load
time.Sleep(100 * time.Millisecond)
}
// Delete object
if err := framework.DeleteTestObject(client, bucketName, objectKey); err != nil {
errors <- err
continue
}
// Small delay to reduce server load
time.Sleep(100 * time.Millisecond)
}
// Delete bucket
if _, err := client.DeleteBucket(&s3.DeleteBucketInput{
@ -167,10 +163,9 @@ func TestS3IAMDistributedTests(t *testing.T) {
}); err != nil {
errors <- err
continue
}
// Small delay to reduce server load
time.Sleep(100 * time.Millisecond)
}
}
}(i)
}
@ -183,15 +178,15 @@ func TestS3IAMDistributedTests(t *testing.T) {
for err := range errors {
errorList = append(errorList, err)
}
totalOperations := numGoroutines * numOperationsPerGoroutine
errorRate := float64(len(errorList)) / float64(totalOperations)
if len(errorList) > 0 {
t.Logf("Concurrent operations: %d/%d operations failed (%.1f%% error rate). First error: %v",
t.Logf("Concurrent operations: %d/%d operations failed (%.1f%% error rate). First error: %v",
len(errorList), totalOperations, errorRate*100, errorList[0])
}
// Allow up to 50% error rate for concurrent stress testing
// This tests that the system handles concurrent load gracefully
if errorRate > 0.5 {

38
weed/iam/oidc/oidc_provider.go

@ -18,11 +18,13 @@ import (
// OIDCProvider implements OpenID Connect authentication
type OIDCProvider struct {
name string
config *OIDCConfig
initialized bool
jwksCache *JWKS
httpClient *http.Client
name string
config *OIDCConfig
initialized bool
jwksCache *JWKS
httpClient *http.Client
jwksFetchedAt time.Time
jwksTTL time.Duration
}
// OIDCConfig holds OIDC provider configuration
@ -50,6 +52,9 @@ type OIDCConfig struct {
// ClaimsMapping defines how to map OIDC claims to identity attributes
ClaimsMapping map[string]string `json:"claimsMapping,omitempty"`
// JWKSCacheTTLSeconds sets how long to cache JWKS before refresh (default 3600 seconds)
JWKSCacheTTLSeconds int `json:"jwksCacheTTLSeconds,omitempty"`
}
// JWKS represents JSON Web Key Set
@ -101,6 +106,13 @@ func (p *OIDCProvider) Initialize(config interface{}) error {
p.config = oidcConfig
p.initialized = true
// Configure JWKS cache TTL
if oidcConfig.JWKSCacheTTLSeconds > 0 {
p.jwksTTL = time.Duration(oidcConfig.JWKSCacheTTLSeconds) * time.Second
} else {
p.jwksTTL = time.Hour
}
// For testing, we'll skip the actual OIDC client initialization
return nil
}
@ -385,8 +397,8 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims)
// getPublicKey retrieves the public key for the given key ID from JWKS
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) {
// Fetch JWKS if not cached or refresh if needed
if p.jwksCache == nil {
// Fetch JWKS if not cached or refresh if expired
if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) {
if err := p.fetchJWKS(ctx); err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %v", err)
}
@ -399,7 +411,16 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{
}
}
return nil, fmt.Errorf("key with ID %s not found in JWKS", kid)
// Key not found in cache. Refresh JWKS once to handle key rotation and retry.
if err := p.fetchJWKS(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err)
}
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
return p.parseJWK(&key)
}
}
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid)
}
// fetchJWKS fetches the JWKS from the provider
@ -430,6 +451,7 @@ func (p *OIDCProvider) fetchJWKS(ctx context.Context) error {
}
p.jwksCache = &jwks
p.jwksFetchedAt = time.Now()
glog.V(3).Infof("Fetched JWKS with %d keys from %s", len(jwks.Keys), jwksURL)
return nil
}

16
weed/iam/providers/provider.go

@ -4,7 +4,7 @@ import (
"context"
"fmt"
"net/mail"
"strings"
"path/filepath"
"time"
)
@ -210,15 +210,9 @@ func (r *MappingRule) Matches(claims *TokenClaims) bool {
// matchValue checks if a value matches the rule value (with wildcard support)
func (r *MappingRule) matchValue(value string) bool {
// Simple wildcard matching
if strings.Contains(r.Value, "*") {
// Convert wildcard to regex-like matching
pattern := strings.ReplaceAll(r.Value, "*", "")
if pattern == "" {
return true // "*" matches everything
}
return strings.Contains(value, pattern)
matched, err := filepath.Match(r.Value, value)
if err != nil {
return false
}
return value == r.Value
return matched
}
Loading…
Cancel
Save