Browse Source

address comments

pull/7160/head
chrislu 1 month ago
parent
commit
20d3f8550f
  1. 2
      test/s3/iam/docker-compose.yml
  2. 31
      test/s3/iam/s3_iam_distributed_test.go
  3. 51
      weed/iam/oidc/oidc_provider.go
  4. 2
      weed/iam/policy/aws_iam_compliance_test.go
  5. 45
      weed/iam/policy/policy_engine.go
  6. 24
      weed/iam/providers/provider.go
  7. 192
      weed/iam/sts/security_test.go
  8. 58
      weed/iam/sts/sts_service.go
  9. 11
      weed/iam/sts/sts_service_test.go
  10. 94
      weed/s3api/s3_iam_middleware.go
  11. 259
      weed/s3api/s3_iam_role_selection_test.go

2
test/s3/iam/docker-compose.yml

@ -132,7 +132,7 @@ services:
# Keycloak Setup Service
keycloak-setup:
image: alpine/curl:latest
image: alpine/curl:8.4.0
container_name: keycloak-setup
volumes:
- ./setup_keycloak_docker.sh:/setup.sh:ro

31
test/s3/iam/s3_iam_distributed_test.go

@ -104,8 +104,8 @@ func TestS3IAMDistributedTests(t *testing.T) {
t.Run("distributed_concurrent_operations", func(t *testing.T) {
// Test concurrent operations across distributed instances
// CI-REALISTIC APPROACH: 8 total operations (4x2) - 33% more than original (6) with pragmatic CI tolerance
// Success rates vary from 62.5%-87.5% due to CI environment resource constraints, but system functions correctly
// STRINGENT APPROACH: 8 total operations (4x2) - 33% more than original (6) with rigorous error detection
// Target >87.5% success rate to catch concurrency regressions while allowing minimal CI infrastructure issues
const numGoroutines = 4 // Optimal concurrency for CI reliability
const numOperationsPerGoroutine = 2 // Minimal operations per goroutine
@ -216,34 +216,33 @@ func TestS3IAMDistributedTests(t *testing.T) {
}
}
// CI-REALISTIC CONCURRENCY TESTING: Pragmatic thresholds based on observed CI environment variability
// For totalOperations=8, success rates vary 62.5%-87.5% due to CI infrastructure constraints
// STRINGENT CONCURRENCY TESTING: More rigorous thresholds to catch regressions while accounting for CI variability
// For totalOperations=8, target >87.5% success rate (≤12.5% error rate) to detect concurrency issues
// Serious errors (race conditions, deadlocks) should be minimal for reliable CI testing
// Allow up to 50% serious errors for CI environment variability (infrastructure limitations, volume allocation issues)
// This accounts for the reality that CI environments have resource constraints that can affect multiple operations
maxSeriousErrors := totalOperations / 2 // Realistic tolerance for CI infrastructure variability
// Serious errors (race conditions, deadlocks) should be very limited - allow only 1 for CI infrastructure issues
// Based on observed data: 1-3 errors due to volume allocation constraints, not actual concurrency bugs
maxSeriousErrors := 1 // Allow 1 serious error (12.5%) for CI infrastructure limitations only
if len(seriousErrors) > maxSeriousErrors {
t.Errorf("❌ %d serious error(s) detected (%.1f%%), exceeding threshold of %d. This indicates potential concurrency bugs. First error: %v",
len(seriousErrors), float64(len(seriousErrors))/float64(totalOperations)*100, maxSeriousErrors, seriousErrors[0])
}
// For total errors, use pragmatic thresholds based on observed CI environment variability
// CI environments can have resource constraints that affect multiple operations simultaneously
maxTotalErrorsStrict := 2 // Allow max 2 total errors (25% rate) - good performance in CI
maxTotalErrorsRelaxed := 4 // Allow max 4 total errors (50% rate) - acceptable for resource-constrained CI
// For total errors, use stringent thresholds to catch regressions while allowing minimal CI infrastructure issues
// Target >87.5% success rate to ensure system reliability and catch concurrency problems early
maxTotalErrorsStrict := 1 // Allow max 1 total error (12.5% rate) - excellent performance target
maxTotalErrorsRelaxed := 2 // Allow max 2 total errors (25% rate) - acceptable with infrastructure constraints
if len(errorList) > maxTotalErrorsRelaxed {
t.Errorf("❌ Too many total errors: %d (%.1f%%) - exceeds relaxed threshold of %d (%.1f%%). System is unstable under concurrent load.",
t.Errorf("❌ Too many total errors: %d (%.1f%%) - exceeds threshold of %d (%.1f%%). System may have concurrency issues.",
len(errorList), errorRate*100, maxTotalErrorsRelaxed, float64(maxTotalErrorsRelaxed)/float64(totalOperations)*100)
} else if len(errorList) > maxTotalErrorsStrict {
t.Logf("⚠️ Concurrent operations completed with %d errors (%.1f%%) - within relaxed CI limits. Normal CI environment variability.",
t.Logf("⚠️ Concurrent operations completed with %d errors (%.1f%%) - acceptable but monitor for patterns.",
len(errorList), errorRate*100)
} else if len(errorList) > 0 {
t.Logf("✅ Concurrent operations completed with %d errors (%.1f%%) - good performance for CI environment!",
t.Logf("✅ Concurrent operations completed with %d errors (%.1f%%) - excellent performance!",
len(errorList), errorRate*100)
} else {
t.Logf("🎉 All %d concurrent operations completed successfully - excellent CI performance!", totalOperations)
t.Logf("🎉 All %d concurrent operations completed successfully - perfect concurrency handling!", totalOperations)
}
})
}

51
weed/iam/oidc/oidc_provider.go

@ -2,6 +2,8 @@ package oidc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
@ -490,6 +492,8 @@ func (p *OIDCProvider) parseJWK(key *JWK) (interface{}, error) {
switch key.Kty {
case "RSA":
return p.parseRSAKey(key)
case "EC":
return p.parseECKey(key)
default:
return nil, fmt.Errorf("unsupported key type: %s", key.Kty)
}
@ -524,6 +528,53 @@ func (p *OIDCProvider) parseRSAKey(key *JWK) (*rsa.PublicKey, error) {
return pubKey, nil
}
// parseECKey parses an Elliptic Curve key from JWK
func (p *OIDCProvider) parseECKey(key *JWK) (*ecdsa.PublicKey, error) {
// Validate required fields
if key.X == "" || key.Y == "" || key.Crv == "" {
return nil, fmt.Errorf("incomplete EC key: missing x, y, or crv parameter")
}
// Get the curve
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported EC curve: %s", key.Crv)
}
// Decode x coordinate
xBytes, err := base64.RawURLEncoding.DecodeString(key.X)
if err != nil {
return nil, fmt.Errorf("failed to decode EC x coordinate: %v", err)
}
// Decode y coordinate
yBytes, err := base64.RawURLEncoding.DecodeString(key.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode EC y coordinate: %v", err)
}
// Create EC public key
pubKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
}
// Validate that the point is on the curve
if !curve.IsOnCurve(pubKey.X, pubKey.Y) {
return nil, fmt.Errorf("EC key coordinates are not on the specified curve")
}
return pubKey, nil
}
// mapUserInfoToIdentity maps UserInfo response to ExternalIdentity
func (p *OIDCProvider) mapUserInfoToIdentity(userInfo map[string]interface{}) *providers.ExternalIdentity {
identity := &providers.ExternalIdentity{

2
weed/iam/policy/aws_iam_compliance_test.go

@ -200,7 +200,7 @@ func TestAWSWildcardMatch(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := awsWildcardMatch(tt.pattern, tt.value)
result := AwsWildcardMatch(tt.pattern, tt.value)
assert.Equal(t, tt.expected, result, "AWS wildcard match should match expected")
})
}

45
weed/iam/policy/policy_engine.go

@ -7,6 +7,7 @@ import (
"path/filepath"
"regexp"
"strings"
"sync"
)
// Effect represents the policy evaluation result
@ -17,6 +18,12 @@ const (
EffectDeny Effect = "Deny"
)
// Package-level regex cache for performance optimization
var (
regexCache = make(map[string]*regexp.Regexp)
regexCacheMu sync.RWMutex
)
// PolicyEngine evaluates policies against requests
type PolicyEngine struct {
config *PolicyEngineConfig
@ -624,7 +631,7 @@ func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool {
// Step 4: Handle AWS-style wildcards (case-insensitive)
if strings.Contains(expandedPattern, "*") || strings.Contains(expandedPattern, "?") {
return awsWildcardMatch(expandedPattern, value)
return AwsWildcardMatch(expandedPattern, value)
}
return false
@ -666,19 +673,37 @@ func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) strin
return defaultValue
}
// awsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
func awsWildcardMatch(pattern, value string) bool {
// Convert pattern to regex with case-insensitive matching
// AWS uses * for any sequence and ? for any single character
// AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
func AwsWildcardMatch(pattern, value string) bool {
// Create regex pattern key for caching
regexPattern := strings.ReplaceAll(pattern, "*", ".*")
regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
regexPattern = "^" + regexPattern + "$"
regexKey := "(?i)" + regexPattern
// Compile with case-insensitive flag
regex, err := regexp.Compile("(?i)" + regexPattern)
if err != nil {
// Fallback to simple case-insensitive comparison if regex fails
return strings.EqualFold(pattern, value)
// Try to get compiled regex from cache
regexCacheMu.RLock()
regex, found := regexCache[regexKey]
regexCacheMu.RUnlock()
if !found {
// Compile and cache the regex
compiledRegex, err := regexp.Compile(regexKey)
if err != nil {
// Fallback to simple case-insensitive comparison if regex fails
return strings.EqualFold(pattern, value)
}
// Store in cache with write lock
regexCacheMu.Lock()
// Double-check in case another goroutine added it
if existingRegex, exists := regexCache[regexKey]; exists {
regex = existingRegex
} else {
regexCache[regexKey] = compiledRegex
regex = compiledRegex
}
regexCacheMu.Unlock()
}
return regex.MatchString(value)

24
weed/iam/providers/provider.go

@ -4,11 +4,10 @@ import (
"context"
"fmt"
"net/mail"
"regexp"
"strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/iam/policy"
)
// IdentityProvider defines the interface for external identity providers
@ -222,26 +221,7 @@ func (r *MappingRule) Matches(claims *TokenClaims) bool {
// matchValue checks if a value matches the rule value (with wildcard support)
// Uses AWS IAM-compliant case-insensitive wildcard matching for consistency with policy engine
func (r *MappingRule) matchValue(value string) bool {
matched := awsWildcardMatch(r.Value, value)
matched := policy.AwsWildcardMatch(r.Value, value)
glog.V(3).Infof("AWS IAM pattern match result: '%s' matches '%s' = %t", value, r.Value, matched)
return matched
}
// awsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
// This function ensures consistent matching behavior across the IAM system
func awsWildcardMatch(pattern, value string) bool {
// Convert pattern to regex with case-insensitive matching
// AWS uses * for any sequence and ? for any single character
regexPattern := strings.ReplaceAll(pattern, "*", ".*")
regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
regexPattern = "^" + regexPattern + "$"
// Compile with case-insensitive flag
regex, err := regexp.Compile("(?i)" + regexPattern)
if err != nil {
// Fallback to simple case-insensitive comparison if regex fails
return strings.EqualFold(pattern, value)
}
return regex.MatchString(value)
}

192
weed/iam/sts/security_test.go

@ -0,0 +1,192 @@
package sts
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
// with specific issuer claims can only be validated by the provider registered for that issuer
func TestSecurityIssuerToProviderMapping(t *testing.T) {
ctx := context.Background()
// Create STS service with two mock providers
service := NewSTSService()
config := &STSConfig{
TokenDuration: time.Hour,
MaxSessionLength: time.Hour * 12,
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Set up mock trust policy validator
mockValidator := &MockTrustPolicyValidator{}
service.SetTrustPolicyValidator(mockValidator)
// Create two mock providers with different issuers
providerA := &MockIdentityProviderWithIssuer{
name: "provider-a",
issuer: "https://provider-a.com",
validTokens: map[string]bool{
"token-for-provider-a": true,
},
}
providerB := &MockIdentityProviderWithIssuer{
name: "provider-b",
issuer: "https://provider-b.com",
validTokens: map[string]bool{
"token-for-provider-b": true,
},
}
// Register both providers
err = service.RegisterProvider(providerA)
require.NoError(t, err)
err = service.RegisterProvider(providerB)
require.NoError(t, err)
// Create JWT tokens with specific issuer claims
tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
// This should succeed - token has issuer A and provider A is registered
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Equal(t, "provider-a", provider.Name())
})
t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
// This should succeed - token has issuer B and provider B is registered
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Equal(t, "provider-b", provider.Name())
})
t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
// Create token with unregistered issuer
tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
// This should fail - no provider registered for this issuer
identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
assert.Error(t, err)
assert.Nil(t, identity)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
})
t.Run("non_jwt_tokens_still_work_with_fallback", func(t *testing.T) {
// Non-JWT tokens should still work via fallback mechanism
identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Equal(t, "provider-a", provider.Name())
})
}
// createTestJWT creates a test JWT token with the specified issuer and subject
func createTestJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
type MockIdentityProviderWithIssuer struct {
name string
issuer string
validTokens map[string]bool
}
func (m *MockIdentityProviderWithIssuer) Name() string {
return m.name
}
func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
return m.issuer
}
func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
// For JWT tokens, parse and validate the token format
if len(token) > 50 && strings.Contains(token, ".") {
// This looks like a JWT - parse it to get the subject
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("invalid JWT token")
}
claims, ok := parsedToken.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid claims")
}
issuer, _ := claims["iss"].(string)
subject, _ := claims["sub"].(string)
// Verify the issuer matches what we expect
if issuer != m.issuer {
return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
}
return &providers.ExternalIdentity{
UserID: subject,
Email: subject + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
// For non-JWT tokens, check our simple token list
if m.validTokens[token] {
return &providers.ExternalIdentity{
UserID: "test-user",
Email: "test@" + m.name + ".com",
Provider: m.name,
}, nil
}
return nil, fmt.Errorf("invalid token")
}
func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: userID,
Email: userID + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if m.validTokens[token] {
return &providers.TokenClaims{
Subject: "test-user",
Issuer: m.issuer,
}, nil
}
return nil, fmt.Errorf("invalid token")
}

58
weed/iam/sts/sts_service.go

@ -346,6 +346,11 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
return nil, fmt.Errorf("invalid request: %w", err)
}
// Check for unsupported session policy
if request.Policy != nil {
return nil, fmt.Errorf("session policies are not currently supported - Policy parameter must be omitted")
}
// 1. Validate the web identity token with appropriate provider
externalIdentity, provider, err := s.validateWebIdentityToken(ctx, request.WebIdentityToken)
if err != nil {
@ -542,29 +547,44 @@ func (s *STSService) validateAssumeRoleWithWebIdentityRequest(request *AssumeRol
return nil
}
// validateWebIdentityToken validates the web identity token with available providers
// validateWebIdentityToken validates the web identity token with strict issuer-to-provider mapping
// SECURITY: JWT tokens with a specific issuer claim MUST only be validated by the provider for that issuer
func (s *STSService) validateWebIdentityToken(ctx context.Context, token string) (*providers.ExternalIdentity, providers.IdentityProvider, error) {
// First, try efficient issuer-based lookup
// Try to extract issuer from JWT token for strict validation
issuer, err := s.extractIssuerFromJWT(token)
if err == nil {
// Look up provider by issuer for O(1) efficiency
if provider, exists := s.issuerToProvider[issuer]; exists {
identity, err := provider.Authenticate(ctx, token)
if err == nil && identity != nil {
return identity, provider, nil
}
// If authentication fails with the expected provider, don't continue
// This prevents tokens from being validated by wrong providers
return nil, nil, fmt.Errorf("token validation failed with issuer %s: %w", issuer, err)
}
if err != nil {
// Token is not a valid JWT or cannot be parsed
// This can happen with non-JWT providers (e.g., LDAP) or test tokens
// For backward compatibility with non-JWT tokens, fall back to trying all providers
glog.V(2).Infof("Token is not a valid JWT (%v), falling back to all providers", err)
return s.validateWithAllProviders(ctx, token)
}
glog.V(2).Infof("No provider registered for issuer %s, falling back to brute-force search", issuer)
} else {
glog.V(2).Infof("Could not extract issuer from token (%v), falling back to brute-force search", err)
// Look up the specific provider for this issuer
provider, exists := s.issuerToProvider[issuer]
if !exists {
// SECURITY: If no provider is registered for this issuer, fail immediately
// This prevents JWT tokens from being validated by unintended providers
return nil, nil, fmt.Errorf("no identity provider registered for issuer: %s", issuer)
}
// Fallback: try all providers (backward compatibility)
// This handles providers that don't have issuer mapping or malformed tokens
// Authenticate with the correct provider for this issuer
identity, err := provider.Authenticate(ctx, token)
if err != nil {
return nil, nil, fmt.Errorf("token validation failed with provider for issuer %s: %w", issuer, err)
}
if identity == nil {
return nil, nil, fmt.Errorf("authentication succeeded but no identity returned for issuer %s", issuer)
}
return identity, provider, nil
}
// validateWithAllProviders is a fallback for non-JWT tokens (e.g., LDAP credentials, test tokens)
// This should only be used when the token is not a valid JWT
func (s *STSService) validateWithAllProviders(ctx context.Context, token string) (*providers.ExternalIdentity, providers.IdentityProvider, error) {
// Try each provider until one succeeds
for _, provider := range s.providers {
identity, err := provider.Authenticate(ctx, token)
if err == nil && identity != nil {
@ -572,7 +592,7 @@ func (s *STSService) validateWebIdentityToken(ctx context.Context, token string)
}
}
return nil, nil, fmt.Errorf("web identity token validation failed with all providers")
return nil, nil, fmt.Errorf("token validation failed with all providers")
}
// extractIssuerFromJWT extracts the issuer (iss) claim from a JWT token without verification

11
weed/iam/sts/sts_service_test.go

@ -410,3 +410,14 @@ func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string)
}
return nil, fmt.Errorf("invalid token")
}
// GetIssuer returns the issuer for this provider (for OIDC providers)
func (m *MockIdentityProvider) GetIssuer() string {
// For OIDC mock provider, return the issuer from the first valid token
// This matches the issuer in the token claims
if m.name == "test-oidc" {
return "test-issuer"
}
// LDAP providers don't have issuers
return ""
}

94
weed/s3api/s3_iam_middleware.go

@ -6,6 +6,7 @@ import (
"net"
"net/http"
"net/url"
"sort"
"strings"
"time"
@ -769,14 +770,25 @@ func (s3iam *S3IAMIntegration) validateOIDCToken(ctx context.Context, token stri
}
// Parse roles (stored as comma-separated string)
roles := strings.Split(rolesAttr, ",")
if len(roles) == 0 {
rolesStr := strings.TrimSpace(rolesAttr)
roles := strings.Split(rolesStr, ",")
// Clean up role names
var cleanRoles []string
for _, role := range roles {
cleanRole := strings.TrimSpace(role)
if cleanRole != "" {
cleanRoles = append(cleanRoles, cleanRole)
}
}
if len(cleanRoles) == 0 {
glog.V(3).Infof("Empty roles list from provider %s", providerName)
continue
}
// Use the first role as the primary role
roleArn := roles[0]
// Determine the primary role using intelligent selection
roleArn := s3iam.selectPrimaryRole(cleanRoles, externalIdentity)
return &OIDCIdentity{
UserID: externalIdentity.UserID,
@ -788,6 +800,80 @@ func (s3iam *S3IAMIntegration) validateOIDCToken(ctx context.Context, token stri
return nil, fmt.Errorf("token not valid for any registered OIDC provider")
}
// selectPrimaryRole intelligently selects the primary role from multiple available roles
// This provides deterministic role selection to prevent unpredictable access control behavior
func (s3iam *S3IAMIntegration) selectPrimaryRole(roles []string, externalIdentity *providers.ExternalIdentity) string {
if len(roles) == 1 {
return roles[0]
}
glog.V(2).Infof("🔍 selectPrimaryRole: Selecting from %d roles: %v", len(roles), roles)
// Strategy 1: Check for explicit primary_role claim
if primaryRole, exists := externalIdentity.Attributes["primary_role"]; exists && primaryRole != "" {
primaryRole = strings.TrimSpace(primaryRole)
// Verify the primary role is in the available roles list
for _, role := range roles {
if strings.EqualFold(role, primaryRole) {
glog.V(2).Infof("🔍 selectPrimaryRole: Using explicit primary_role: %s", role)
return role
}
}
glog.V(1).Infof("⚠️ selectPrimaryRole: primary_role '%s' not found in available roles, falling back", primaryRole)
}
// Strategy 2: Role hierarchy - select most privileged role
selectedRole := s3iam.selectByRoleHierarchy(roles)
if selectedRole != "" {
glog.V(2).Infof("🔍 selectPrimaryRole: Using hierarchical selection: %s", selectedRole)
return selectedRole
}
// Strategy 3: Deterministic fallback - alphabetical order (consistent behavior)
sort.Strings(roles)
glog.V(2).Infof("🔍 selectPrimaryRole: Using deterministic selection (first alphabetically): %s", roles[0])
return roles[0]
}
// selectByRoleHierarchy selects a role based on predefined privilege hierarchy
// Returns the most privileged role available, or empty string if no hierarchy match
func (s3iam *S3IAMIntegration) selectByRoleHierarchy(roles []string) string {
// Define role hierarchy from most privileged to least privileged
// This covers common enterprise role naming patterns
roleHierarchy := [][]string{
// Tier 1: Super Admin roles
{"SuperAdmin", "super-admin", "super_admin", "root", "owner"},
// Tier 2: Admin roles
{"Admin", "admin", "Administrator", "administrator", "system-admin", "system_admin"},
// Tier 3: Manager/Power User roles
{"Manager", "manager", "PowerUser", "power-user", "power_user", "lead", "supervisor"},
// Tier 4: Editor/Write roles
{"Editor", "editor", "Writer", "writer", "Contributor", "contributor", "write", "readwrite", "read-write"},
// Tier 5: Viewer/Read roles
{"Viewer", "viewer", "Reader", "reader", "read-only", "read_only", "readonly", "read", "guest"},
}
// Find the highest priority role available
for _, tier := range roleHierarchy {
for _, privilegedRole := range tier {
for _, availableRole := range roles {
// Check for exact match or contains match (case-insensitive)
if strings.EqualFold(availableRole, privilegedRole) ||
strings.Contains(strings.ToLower(availableRole), strings.ToLower(privilegedRole)) {
return availableRole
}
}
}
}
// No hierarchy match found
return ""
}
// getProviderNames returns a list of provider names for debugging
func getProviderNames(providers map[string]providers.IdentityProvider) []string {
names := make([]string, 0, len(providers))

259
weed/s3api/s3_iam_role_selection_test.go

@ -0,0 +1,259 @@
package s3api
import (
"strings"
"testing"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
)
func TestSelectPrimaryRole(t *testing.T) {
s3iam := &S3IAMIntegration{}
t.Run("single_role_returns_that_role", func(t *testing.T) {
roles := []string{"admin"}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "admin", result)
})
t.Run("explicit_primary_role_takes_precedence", func(t *testing.T) {
roles := []string{"admin", "reader", "writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: map[string]string{
"primary_role": "reader",
},
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "reader", result)
})
t.Run("explicit_primary_role_case_insensitive", func(t *testing.T) {
roles := []string{"Admin", "Reader", "Writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: map[string]string{
"primary_role": "admin",
},
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "Admin", result)
})
t.Run("invalid_primary_role_falls_back_to_hierarchy", func(t *testing.T) {
roles := []string{"admin", "reader", "writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: map[string]string{
"primary_role": "nonexistent",
},
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "admin", result) // Should select admin via hierarchy
})
t.Run("hierarchy_selection_admin_over_reader", func(t *testing.T) {
roles := []string{"reader", "admin", "writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "admin", result) // Admin has higher priority
})
t.Run("hierarchy_selection_case_insensitive", func(t *testing.T) {
roles := []string{"Reader", "ADMIN", "writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "ADMIN", result)
})
t.Run("hierarchy_selection_contains_match", func(t *testing.T) {
roles := []string{"system-reader", "system-admin-user", "system-writer"}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "system-admin-user", result) // Contains "admin"
})
t.Run("deterministic_fallback_alphabetical", func(t *testing.T) {
// Roles that don't match any hierarchy
roles := []string{"zebra", "alpha", "beta"}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "alpha", result) // First alphabetically
})
t.Run("complex_enterprise_roles", func(t *testing.T) {
roles := []string{
"app-user-readonly",
"app-user-contributor",
"app-admin-full",
"system-guest",
}
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
result := s3iam.selectPrimaryRole(roles, externalIdentity)
assert.Equal(t, "app-admin-full", result) // Contains "admin"
})
}
func TestSelectByRoleHierarchy(t *testing.T) {
s3iam := &S3IAMIntegration{}
t.Run("super_admin_highest_priority", func(t *testing.T) {
roles := []string{"admin", "super-admin", "reader"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "super-admin", result)
})
t.Run("admin_over_manager", func(t *testing.T) {
roles := []string{"manager", "admin", "reader"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "admin", result)
})
t.Run("manager_over_editor", func(t *testing.T) {
roles := []string{"editor", "manager", "reader"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "manager", result)
})
t.Run("editor_over_viewer", func(t *testing.T) {
roles := []string{"viewer", "editor"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "editor", result)
})
t.Run("no_hierarchy_match_returns_empty", func(t *testing.T) {
roles := []string{"custom-role-1", "custom-role-2", "special-user"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "", result)
})
t.Run("multiple_same_tier_returns_first_found", func(t *testing.T) {
roles := []string{"viewer", "reader", "guest"}
result := s3iam.selectByRoleHierarchy(roles)
// Should return first match found in the hierarchy (viewer comes first in tier definition)
assert.Equal(t, "viewer", result)
})
t.Run("case_variations", func(t *testing.T) {
roles := []string{"ADMIN", "Reader", "writer"}
result := s3iam.selectByRoleHierarchy(roles)
assert.Equal(t, "ADMIN", result)
})
}
func TestRoleSelectionIntegration(t *testing.T) {
t.Run("real_world_enterprise_scenario", func(t *testing.T) {
// Simulate a real enterprise OIDC token with multiple roles
testCases := []struct {
name string
roles []string
primaryRole string // explicit primary_role claim
expectedRole string
selectionType string
}{
{
name: "explicit_primary_overrides_hierarchy",
roles: []string{"admin", "reader", "writer"},
primaryRole: "reader",
expectedRole: "reader",
selectionType: "explicit",
},
{
name: "hierarchy_selects_admin_over_others",
roles: []string{"contributor", "admin", "viewer"},
primaryRole: "", // No explicit primary
expectedRole: "admin",
selectionType: "hierarchy",
},
{
name: "deterministic_fallback_for_unknown_roles",
roles: []string{"zebra-role", "alpha-role", "beta-role"},
primaryRole: "",
expectedRole: "alpha-role",
selectionType: "deterministic",
},
{
name: "complex_enterprise_naming",
roles: []string{"org-user-readonly", "org-power-user", "org-system-admin"},
primaryRole: "",
expectedRole: "org-system-admin", // Contains "admin"
selectionType: "hierarchy",
},
}
s3iam := &S3IAMIntegration{}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
externalIdentity := &providers.ExternalIdentity{
Attributes: make(map[string]string),
}
if tc.primaryRole != "" {
externalIdentity.Attributes["primary_role"] = tc.primaryRole
}
result := s3iam.selectPrimaryRole(tc.roles, externalIdentity)
assert.Equal(t, tc.expectedRole, result,
"Expected %s selection to return %s, got %s",
tc.selectionType, tc.expectedRole, result)
})
}
})
}
// Test helper function to verify role parsing improvements
func TestRoleParsingImprovements(t *testing.T) {
t.Run("whitespace_handling", func(t *testing.T) {
// Test the improved role parsing logic
rolesStr := " admin , reader , writer "
roles := strings.Split(rolesStr, ",")
// Clean up role names (this is what the main code does now)
var cleanRoles []string
for _, role := range roles {
cleanRole := strings.TrimSpace(role)
if cleanRole != "" {
cleanRoles = append(cleanRoles, cleanRole)
}
}
expected := []string{"admin", "reader", "writer"}
assert.Equal(t, expected, cleanRoles)
})
t.Run("empty_roles_filtered", func(t *testing.T) {
rolesStr := "admin,,reader, ,writer"
roles := strings.Split(rolesStr, ",")
var cleanRoles []string
for _, role := range roles {
cleanRole := strings.TrimSpace(role)
if cleanRole != "" {
cleanRoles = append(cleanRoles, cleanRole)
}
}
expected := []string{"admin", "reader", "writer"}
assert.Equal(t, expected, cleanRoles)
})
}
Loading…
Cancel
Save