From 31df6b1ac41ab88869a1f3f6162105737981d6ba Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Sun, 11 Jan 2026 20:46:10 -0800 Subject: [PATCH] fix: address PR feedback (Round 5) - JWT tokens, ARN formatting, PrincipalArn CRITICAL FIXES: - Replace standalone credential generation with STS service JWT tokens - handleAssumeRole now generates proper JWT session tokens - handleAssumeRoleWithLDAPIdentity now generates proper JWT session tokens - Session tokens can be validated across distributed instances - Fix ARN formatting in responses - Extract role name from ARN using utils.ExtractRoleNameFromArn() - Prevents malformed ARNs like "arn:aws:sts::assumed-role/arn:aws:iam::..." - Add configurable AccountId for federated users - Add AccountId field to STSConfig (defaults to "111122223333") - PrincipalArn now uses configured account ID instead of hardcoded "aws" - Enables proper trust policy validation IMPROVEMENTS: - Sanitize LDAP authentication error messages (don't leak internal details) - Remove duplicate comment in provider detection - Add utils import for ARN parsing utilities --- weed/iam/sts/sts_service.go | 7 ++ weed/s3api/s3api_sts.go | 138 +++++++++++++++++++++++------------- 2 files changed, 94 insertions(+), 51 deletions(-) diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index 1d3716099..ceee3a8ec 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -95,8 +95,15 @@ type STSConfig struct { // SigningKey is used to sign session tokens SigningKey []byte `json:"signingKey"` + // AccountId is the AWS account ID used for federated user ARNs + // Defaults to "111122223333" if not specified + AccountId string `json:"accountId,omitempty"` + // Providers configuration - enables automatic provider loading Providers []*ProviderConfig `json:"providers,omitempty"` + + // TokenGenerator is used internally for JWT generation (not serialized) + TokenGenerator *TokenGenerator `json:"-"` } // ProviderConfig holds identity provider configuration diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 20f02ff9c..0879d9ce1 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -5,19 +5,17 @@ package s3api // AWS SDKs to obtain temporary credentials using OIDC/JWT tokens. import ( - "crypto/rand" - "encoding/hex" "encoding/xml" "errors" "fmt" "net/http" "strconv" - "strings" "time" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/iam/ldap" "github.com/seaweedfs/seaweedfs/weed/iam/sts" + "github.com/seaweedfs/seaweedfs/weed/iam/utils" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" ) @@ -68,37 +66,9 @@ func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { return &ds, "", nil } -// generateSecureCredentials generates cryptographically secure temporary credentials -func generateSecureCredentials(userPrefix string, duration time.Duration) (accessKey, secretKey, sessionToken string, expiration time.Time, err error) { - // Generate access key with prefix and random suffix - accessKeyBytes := make([]byte, 10) - if _, err = rand.Read(accessKeyBytes); err != nil { - return "", "", "", time.Time{}, fmt.Errorf("failed to generate access key: %w", err) - } - // Use ASIA prefix for temporary credentials (AWS convention) - prefixLen := len(userPrefix) - if prefixLen > 4 { - prefixLen = 4 - } - accessKey = fmt.Sprintf("ASIA%s%s", strings.ToUpper(userPrefix[:prefixLen]), hex.EncodeToString(accessKeyBytes)[:12]) - - // Generate cryptographically secure secret key (40 hex characters = 20 bytes) - secretKeyBytes := make([]byte, 20) - if _, err = rand.Read(secretKeyBytes); err != nil { - return "", "", "", time.Time{}, fmt.Errorf("failed to generate secret key: %w", err) - } - secretKey = hex.EncodeToString(secretKeyBytes) - - // Generate session token (64 hex characters = 32 bytes) - sessionTokenBytes := make([]byte, 32) - if _, err = rand.Read(sessionTokenBytes); err != nil { - return "", "", "", time.Time{}, fmt.Errorf("failed to generate session token: %w", err) - } - sessionToken = hex.EncodeToString(sessionTokenBytes) - - expiration = time.Now().Add(duration) - return accessKey, secretKey, sessionToken, expiration, nil -} +// Removed generateSecureCredentials - now using STS service's JWT token generation +// The STS service generates proper JWT tokens with embedded claims that can be validated +// across distributed instances without shared state. // STSHandlers provides HTTP handlers for STS operations type STSHandlers struct { @@ -310,26 +280,56 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { duration = time.Duration(*durationSeconds) * time.Second } - // Generate cryptographically secure temporary credentials - tempAccessKey, tempSecretKey, sessionToken, expiration, err := generateSecureCredentials(identity.Name, duration) + // Generate session ID and create JWT token with embedded claims + sessionId, err := sts.GenerateSessionId() + if err != nil { + glog.Errorf("AssumeRole: failed to generate session ID: %v", err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, err) + return + } + + expiration := time.Now().Add(duration) + + // Extract role name from ARN for proper response formatting + roleName := utils.ExtractRoleNameFromArn(roleArn) + if roleName == "" { + roleName = roleArn // Fallback to full ARN if extraction fails + } + + // Create session claims with role information + claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). + WithSessionName(roleSessionName). + WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), identity.PrincipalArn) + + // Generate JWT session token + sessionToken, err := h.stsService.Config.TokenGenerator.GenerateJWTWithClaims(claims) + if err != nil { + glog.Errorf("AssumeRole: failed to generate session token: %v", err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, err) + return + } + + // Generate temporary credentials from session ID (deterministic) + credGen := sts.NewCredentialGenerator() + creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration) if err != nil { glog.Errorf("AssumeRole: failed to generate credentials: %v", err) h.writeSTSErrorResponse(w, r, STSErrInternalError, err) return } - // Build and return response + // Build and return response with proper ARN formatting xmlResponse := &AssumeRoleResponse{ Result: AssumeRoleResult{ Credentials: STSCredentials{ - AccessKeyId: tempAccessKey, - SecretAccessKey: tempSecretKey, + AccessKeyId: creds.AccessKeyId, + SecretAccessKey: creds.SecretAccessKey, SessionToken: sessionToken, Expiration: expiration.Format(time.RFC3339), }, AssumedRoleUser: &AssumedRoleUser{ - AssumedRoleId: fmt.Sprintf("%s:%s", roleArn, roleSessionName), - Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleArn, roleSessionName), + AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName), + Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleName, roleSessionName), }, }, } @@ -385,7 +385,6 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r return } - // Find an LDAP provider from the registered providers // Find an LDAP provider from the registered providers var ldapProvider *ldap.LDAPProvider for _, provider := range h.stsService.GetProviders() { @@ -410,7 +409,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r if err != nil { glog.V(2).Infof("AssumeRoleWithLDAPIdentity: LDAP authentication failed for user %s: %v", ldapUsername, err) h.writeSTSErrorResponse(w, r, STSErrAccessDenied, - fmt.Errorf("LDAP authentication failed: %v", err)) + fmt.Errorf("authentication failed")) return } @@ -420,6 +419,12 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r // Verify that the identity is allowed to assume the role // We create a temporary identity to represent the LDAP user for permission checking // The checking logic will verify if the role's trust policy allows this principal + // Use configured account ID or default to "111122223333" for federated users + accountId := "111122223333" // Default account ID for federated users + if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" { + accountId = h.stsService.Config.AccountId + } + ldapUserIdentity := &Identity{ Name: identity.UserID, Account: &Account{ @@ -427,7 +432,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r EmailAddress: identity.Email, Id: identity.UserID, }, - PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/%s", "aws", identity.UserID), + PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/%s", accountId, identity.UserID), } if authErr := h.iam.VerifyActionPermission(r, ldapUserIdentity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone { @@ -442,26 +447,57 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r duration = time.Duration(*durationSeconds) * time.Second } - // Generate cryptographically secure temporary credentials - tempAccessKey, tempSecretKey, sessionToken, expiration, err := generateSecureCredentials(ldapUsername, duration) + // Generate session ID and create JWT token with embedded claims + sessionId, err := sts.GenerateSessionId() + if err != nil { + glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session ID: %v", err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, err) + return + } + + expiration := time.Now().Add(duration) + + // Extract role name from ARN for proper response formatting + roleName := utils.ExtractRoleNameFromArn(roleArn) + if roleName == "" { + roleName = roleArn // Fallback to full ARN if extraction fails + } + + // Create session claims with role and LDAP provider information + claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). + WithSessionName(roleSessionName). + WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), ldapUserIdentity.PrincipalArn). + WithIdentityProvider("ldap", identity.UserID, identity.Provider) + + // Generate JWT session token + sessionToken, err := h.stsService.Config.TokenGenerator.GenerateJWTWithClaims(claims) + if err != nil { + glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session token: %v", err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, err) + return + } + + // Generate temporary credentials from session ID (deterministic) + credGen := sts.NewCredentialGenerator() + creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration) if err != nil { glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate credentials: %v", err) h.writeSTSErrorResponse(w, r, STSErrInternalError, err) return } - // Build and return response + // Build and return response with proper ARN formatting xmlResponse := &AssumeRoleWithLDAPIdentityResponse{ Result: LDAPIdentityResult{ Credentials: STSCredentials{ - AccessKeyId: tempAccessKey, - SecretAccessKey: tempSecretKey, + AccessKeyId: creds.AccessKeyId, + SecretAccessKey: creds.SecretAccessKey, SessionToken: sessionToken, Expiration: expiration.Format(time.RFC3339), }, AssumedRoleUser: &AssumedRoleUser{ - AssumedRoleId: fmt.Sprintf("%s:%s", roleArn, roleSessionName), - Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleArn, roleSessionName), + AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName), + Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleName, roleSessionName), }, }, }