Browse Source

refactor(s3api): extract shared STS credential generation logic

- Move common logic for session claims and credential generation to prepareSTSCredentials
- Update handleAssumeRole and handleAssumeRoleWithLDAPIdentity to use the helper
- Remove stale comments referencing outdated line numbers
pull/8003/head
Chris Lu 2 days ago
parent
commit
678aeeff0d
  1. 145
      weed/s3api/s3api_sts.go

145
weed/s3api/s3api_sts.go

@ -274,69 +274,18 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
return return
} }
// Calculate duration
duration := time.Hour // Default 1 hour
if durationSeconds != nil {
duration = time.Duration(*durationSeconds) * time.Second
}
// Generate session ID and create JWT token with embedded claims
sessionId, err := sts.GenerateSessionId()
if err != nil {
glog.Errorf("AssumeRole: failed to generate session ID: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
expiration := time.Now().Add(duration)
// Extract role name from ARN for proper response formatting
roleName := utils.ExtractRoleNameFromArn(roleArn)
if roleName == "" {
roleName = roleArn // Fallback to full ARN if extraction fails
}
// Create session claims with role information
claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration).
WithSessionName(roleSessionName).
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), identity.PrincipalArn)
// Generate JWT session token
sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims)
// Generate common STS components
stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, identity.PrincipalArn, durationSeconds, nil)
if err != nil { if err != nil {
glog.Errorf("AssumeRole: failed to generate session token: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err) h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return return
} }
// Generate temporary credentials from session ID (deterministic)
credGen := sts.NewCredentialGenerator()
creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration)
if err != nil {
glog.Errorf("AssumeRole: failed to generate credentials: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Get account ID from STS config or use default
accountId := "111122223333" // Default account ID
if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" {
accountId = h.stsService.Config.AccountId
}
// Build and return response with proper ARN formatting
// Build and return response
xmlResponse := &AssumeRoleResponse{ xmlResponse := &AssumeRoleResponse{
Result: AssumeRoleResult{ Result: AssumeRoleResult{
Credentials: STSCredentials{
AccessKeyId: creds.AccessKeyId,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: sessionToken,
Expiration: expiration.Format(time.RFC3339),
},
AssumedRoleUser: &AssumedRoleUser{
AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName),
},
Credentials: stsCreds,
AssumedRoleUser: assumedUser,
}, },
} }
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
@ -447,18 +396,43 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
return return
} }
// Generate common STS components with LDAP-specific claims
modifyClaims := func(claims *sts.STSSessionClaims) {
claims.WithIdentityProvider("ldap", identity.UserID, identity.Provider)
}
stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, ldapUserIdentity.PrincipalArn, durationSeconds, modifyClaims)
if err != nil {
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
}
// Build and return response
xmlResponse := &AssumeRoleWithLDAPIdentityResponse{
Result: LDAPIdentityResult{
Credentials: stsCreds,
AssumedRoleUser: assumedUser,
},
}
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
}
// prepareSTSCredentials extracts common shared logic for credential generation
func (h *STSHandlers) prepareSTSCredentials(roleArn, roleSessionName, principalArn string,
durationSeconds *int64, modifyClaims func(*sts.STSSessionClaims)) (STSCredentials, *AssumedRoleUser, error) {
// Calculate duration // Calculate duration
duration := time.Hour // Default 1 hour duration := time.Hour // Default 1 hour
if durationSeconds != nil { if durationSeconds != nil {
duration = time.Duration(*durationSeconds) * time.Second duration = time.Duration(*durationSeconds) * time.Second
} }
// Generate session ID and create JWT token with embedded claims
// Generate session ID
sessionId, err := sts.GenerateSessionId() sessionId, err := sts.GenerateSessionId()
if err != nil { if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session ID: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
return STSCredentials{}, nil, fmt.Errorf("failed to generate session ID: %w", err)
} }
expiration := time.Now().Add(duration) expiration := time.Now().Add(duration)
@ -469,49 +443,48 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
roleName = roleArn // Fallback to full ARN if extraction fails roleName = roleArn // Fallback to full ARN if extraction fails
} }
// Create session claims with role and LDAP provider information
// Create session claims with role information
claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration).
WithSessionName(roleSessionName). WithSessionName(roleSessionName).
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), ldapUserIdentity.PrincipalArn).
WithIdentityProvider("ldap", identity.UserID, identity.Provider)
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), principalArn)
// Apply custom claims if provided (e.g., LDAP identity)
if modifyClaims != nil {
modifyClaims(claims)
}
// Generate JWT session token // Generate JWT session token
sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims) sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims)
if err != nil { if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session token: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
return STSCredentials{}, nil, fmt.Errorf("failed to generate session token: %w", err)
} }
// Generate temporary credentials from session ID (deterministic) // Generate temporary credentials from session ID (deterministic)
credGen := sts.NewCredentialGenerator() credGen := sts.NewCredentialGenerator()
creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration) creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration)
if err != nil { if err != nil {
glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate credentials: %v", err)
h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return
return STSCredentials{}, nil, fmt.Errorf("failed to generate credentials: %w", err)
} }
// Build and return response with proper ARN formatting
// accountId is already defined above (line 423-426)
// Get account ID from STS config or use default
accountId := "111122223333" // Default account ID
if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" {
accountId = h.stsService.Config.AccountId
}
xmlResponse := &AssumeRoleWithLDAPIdentityResponse{
Result: LDAPIdentityResult{
Credentials: STSCredentials{
AccessKeyId: creds.AccessKeyId,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: sessionToken,
Expiration: expiration.Format(time.RFC3339),
},
AssumedRoleUser: &AssumedRoleUser{
AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName),
},
},
stsCreds := STSCredentials{
AccessKeyId: creds.AccessKeyId,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: sessionToken,
Expiration: expiration.Format(time.RFC3339),
} }
xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano())
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
assumedUser := &AssumedRoleUser{
AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName),
Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName),
}
return stsCreds, assumedUser, nil
} }
// STS Response types for XML marshaling // STS Response types for XML marshaling

Loading…
Cancel
Save