Browse Source

fix: address PR feedback (Round 5) - JWT tokens, ARN formatting, PrincipalArn

CRITICAL FIXES:
- Replace standalone credential generation with STS service JWT tokens
  - handleAssumeRole now generates proper JWT session tokens
  - handleAssumeRoleWithLDAPIdentity now generates proper JWT session tokens
  - Session tokens can be validated across distributed instances

- Fix ARN formatting in responses
  - Extract role name from ARN using utils.ExtractRoleNameFromArn()
  - Prevents malformed ARNs like "arn:aws:sts::assumed-role/arn:aws:iam::..."

- Add configurable AccountId for federated users
  - Add AccountId field to STSConfig (defaults to "111122223333")
  - PrincipalArn now uses configured account ID instead of hardcoded "aws"
  - Enables proper trust policy validation

IMPROVEMENTS:
- Sanitize LDAP authentication error messages (don't leak internal details)
- Remove duplicate comment in provider detection
- Add utils import for ARN parsing utilities
pull/8003/head
Chris Lu 2 days ago
parent
commit
31df6b1ac4
  1. 7
      weed/iam/sts/sts_service.go
  2. 138
      weed/s3api/s3api_sts.go

7
weed/iam/sts/sts_service.go

@ -95,8 +95,15 @@ type STSConfig struct {
// SigningKey is used to sign session tokens
SigningKey []byte `json:"signingKey"`
// AccountId is the AWS account ID used for federated user ARNs
// Defaults to "111122223333" if not specified
AccountId string `json:"accountId,omitempty"`
// Providers configuration - enables automatic provider loading
Providers []*ProviderConfig `json:"providers,omitempty"`
// TokenGenerator is used internally for JWT generation (not serialized)
TokenGenerator *TokenGenerator `json:"-"`
}
// ProviderConfig holds identity provider configuration

138
weed/s3api/s3api_sts.go

@ -5,19 +5,17 @@ package s3api
// AWS SDKs to obtain temporary credentials using OIDC/JWT tokens.
import (
"crypto/rand"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/iam/ldap"
"github.com/seaweedfs/seaweedfs/weed/iam/sts"
"github.com/seaweedfs/seaweedfs/weed/iam/utils"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
@ -68,37 +66,9 @@ func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) {
return &ds, "", nil
}
// generateSecureCredentials generates cryptographically secure temporary credentials
func generateSecureCredentials(userPrefix string, duration time.Duration) (accessKey, secretKey, sessionToken string, expiration time.Time, err error) {
// Generate access key with prefix and random suffix
accessKeyBytes := make([]byte, 10)
if _, err = rand.Read(accessKeyBytes); err != nil {
return "", "", "", time.Time{}, fmt.Errorf("failed to generate access key: %w", err)
}
// Use ASIA prefix for temporary credentials (AWS convention)
prefixLen := len(userPrefix)
if prefixLen > 4 {
prefixLen = 4
}
accessKey = fmt.Sprintf("ASIA%s%s", strings.ToUpper(userPrefix[:prefixLen]), hex.EncodeToString(accessKeyBytes)[:12])
// Generate cryptographically secure secret key (40 hex characters = 20 bytes)
secretKeyBytes := make([]byte, 20)
if _, err = rand.Read(secretKeyBytes); err != nil {
return "", "", "", time.Time{}, fmt.Errorf("failed to generate secret key: %w", err)
}
secretKey = hex.EncodeToString(secretKeyBytes)
// Generate session token (64 hex characters = 32 bytes)
sessionTokenBytes := make([]byte, 32)
if _, err = rand.Read(sessionTokenBytes); err != nil {
return "", "", "", time.Time{}, fmt.Errorf("failed to generate session token: %w", err)
}
sessionToken = hex.EncodeToString(sessionTokenBytes)
expiration = time.Now().Add(duration)
return accessKey, secretKey, sessionToken, expiration, nil
}
// Removed generateSecureCredentials - now using STS service's JWT token generation
// The STS service generates proper JWT tokens with embedded claims that can be validated
// across distributed instances without shared state.
// STSHandlers provides HTTP handlers for STS operations
type STSHandlers struct {
@ -310,26 +280,56 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
duration = time.Duration(*durationSeconds) * time.Second
}
// Generate cryptographically secure temporary credentials
tempAccessKey, tempSecretKey, sessionToken, expiration, err := generateSecureCredentials(identity.Name, duration)
// Generate session ID and create JWT token with embedded claims
sessionId, err := sts.GenerateSessionId()
if err != nil {
glog.Errorf("AssumeRole: failed to generate session ID: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
expiration := time.Now().Add(duration)
// Extract role name from ARN for proper response formatting
roleName := utils.ExtractRoleNameFromArn(roleArn)
if roleName == "" {
roleName = roleArn // Fallback to full ARN if extraction fails
}
// Create session claims with role information
claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration).
WithSessionName(roleSessionName).
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), identity.PrincipalArn)
// Generate JWT session token
sessionToken, err := h.stsService.Config.TokenGenerator.GenerateJWTWithClaims(claims)
if err != nil {
glog.Errorf("AssumeRole: failed to generate session token: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Generate temporary credentials from session ID (deterministic)
credGen := sts.NewCredentialGenerator()
creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration)
if err != nil {
glog.Errorf("AssumeRole: failed to generate credentials: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Build and return response
// Build and return response with proper ARN formatting
xmlResponse := &AssumeRoleResponse{
Result: AssumeRoleResult{
Credentials: STSCredentials{
AccessKeyId: tempAccessKey,
SecretAccessKey: tempSecretKey,
AccessKeyId: creds.AccessKeyId,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: sessionToken,
Expiration: expiration.Format(time.RFC3339),
},
AssumedRoleUser: &AssumedRoleUser{
AssumedRoleId: fmt.Sprintf("%s:%s", roleArn, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleArn, roleSessionName),
AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleName, roleSessionName),
},
},
}
@ -385,7 +385,6 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
return
}
// Find an LDAP provider from the registered providers
// Find an LDAP provider from the registered providers
var ldapProvider *ldap.LDAPProvider
for _, provider := range h.stsService.GetProviders() {
@ -410,7 +409,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
if err != nil {
glog.V(2).Infof("AssumeRoleWithLDAPIdentity: LDAP authentication failed for user %s: %v", ldapUsername, err)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied,
fmt.Errorf("LDAP authentication failed: %v", err))
fmt.Errorf("authentication failed"))
return
}
@ -420,6 +419,12 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
// Verify that the identity is allowed to assume the role
// We create a temporary identity to represent the LDAP user for permission checking
// The checking logic will verify if the role's trust policy allows this principal
// Use configured account ID or default to "111122223333" for federated users
accountId := "111122223333" // Default account ID for federated users
if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" {
accountId = h.stsService.Config.AccountId
}
ldapUserIdentity := &Identity{
Name: identity.UserID,
Account: &Account{
@ -427,7 +432,7 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
EmailAddress: identity.Email,
Id: identity.UserID,
},
PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/%s", "aws", identity.UserID),
PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/%s", accountId, identity.UserID),
}
if authErr := h.iam.VerifyActionPermission(r, ldapUserIdentity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone {
@ -442,26 +447,57 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
duration = time.Duration(*durationSeconds) * time.Second
}
// Generate cryptographically secure temporary credentials
tempAccessKey, tempSecretKey, sessionToken, expiration, err := generateSecureCredentials(ldapUsername, duration)
// Generate session ID and create JWT token with embedded claims
sessionId, err := sts.GenerateSessionId()
if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session ID: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
expiration := time.Now().Add(duration)
// Extract role name from ARN for proper response formatting
roleName := utils.ExtractRoleNameFromArn(roleArn)
if roleName == "" {
roleName = roleArn // Fallback to full ARN if extraction fails
}
// Create session claims with role and LDAP provider information
claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration).
WithSessionName(roleSessionName).
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), ldapUserIdentity.PrincipalArn).
WithIdentityProvider("ldap", identity.UserID, identity.Provider)
// Generate JWT session token
sessionToken, err := h.stsService.Config.TokenGenerator.GenerateJWTWithClaims(claims)
if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session token: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Generate temporary credentials from session ID (deterministic)
credGen := sts.NewCredentialGenerator()
creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration)
if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate credentials: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Build and return response
// Build and return response with proper ARN formatting
xmlResponse := &AssumeRoleWithLDAPIdentityResponse{
Result: LDAPIdentityResult{
Credentials: STSCredentials{
AccessKeyId: tempAccessKey,
SecretAccessKey: tempSecretKey,
AccessKeyId: creds.AccessKeyId,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: sessionToken,
Expiration: expiration.Format(time.RFC3339),
},
AssumedRoleUser: &AssumedRoleUser{
AssumedRoleId: fmt.Sprintf("%s:%s", roleArn, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleArn, roleSessionName),
AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::assumed-role/%s/%s", roleName, roleSessionName),
},
},
}

Loading…
Cancel
Save