diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 64cee26e4..4ae8a24c7 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -274,69 +274,18 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { return } - // Calculate duration - duration := time.Hour // Default 1 hour - if durationSeconds != nil { - duration = time.Duration(*durationSeconds) * time.Second - } - - // Generate session ID and create JWT token with embedded claims - sessionId, err := sts.GenerateSessionId() - if err != nil { - glog.Errorf("AssumeRole: failed to generate session ID: %v", err) - h.writeSTSErrorResponse(w, r, STSErrInternalError, err) - return - } - - expiration := time.Now().Add(duration) - - // Extract role name from ARN for proper response formatting - roleName := utils.ExtractRoleNameFromArn(roleArn) - if roleName == "" { - roleName = roleArn // Fallback to full ARN if extraction fails - } - - // Create session claims with role information - claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). - WithSessionName(roleSessionName). - WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), identity.PrincipalArn) - - // Generate JWT session token - sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims) + // Generate common STS components + stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, identity.PrincipalArn, durationSeconds, nil) if err != nil { - glog.Errorf("AssumeRole: failed to generate session token: %v", err) h.writeSTSErrorResponse(w, r, STSErrInternalError, err) return } - // Generate temporary credentials from session ID (deterministic) - credGen := sts.NewCredentialGenerator() - creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration) - if err != nil { - glog.Errorf("AssumeRole: failed to generate credentials: %v", err) - h.writeSTSErrorResponse(w, r, STSErrInternalError, err) - return - } - - // Get account ID from STS config or use default - accountId := "111122223333" // Default account ID - if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" { - accountId = h.stsService.Config.AccountId - } - - // Build and return response with proper ARN formatting + // Build and return response xmlResponse := &AssumeRoleResponse{ Result: AssumeRoleResult{ - Credentials: STSCredentials{ - AccessKeyId: creds.AccessKeyId, - SecretAccessKey: creds.SecretAccessKey, - SessionToken: sessionToken, - Expiration: expiration.Format(time.RFC3339), - }, - AssumedRoleUser: &AssumedRoleUser{ - AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName), - Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName), - }, + Credentials: stsCreds, + AssumedRoleUser: assumedUser, }, } xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) @@ -447,18 +396,43 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r return } + // Generate common STS components with LDAP-specific claims + modifyClaims := func(claims *sts.STSSessionClaims) { + claims.WithIdentityProvider("ldap", identity.UserID, identity.Provider) + } + + stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, ldapUserIdentity.PrincipalArn, durationSeconds, modifyClaims) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, err) + return + } + + // Build and return response + xmlResponse := &AssumeRoleWithLDAPIdentityResponse{ + Result: LDAPIdentityResult{ + Credentials: stsCreds, + AssumedRoleUser: assumedUser, + }, + } + xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + +// prepareSTSCredentials extracts common shared logic for credential generation +func (h *STSHandlers) prepareSTSCredentials(roleArn, roleSessionName, principalArn string, + durationSeconds *int64, modifyClaims func(*sts.STSSessionClaims)) (STSCredentials, *AssumedRoleUser, error) { + // Calculate duration duration := time.Hour // Default 1 hour if durationSeconds != nil { duration = time.Duration(*durationSeconds) * time.Second } - // Generate session ID and create JWT token with embedded claims + // Generate session ID sessionId, err := sts.GenerateSessionId() if err != nil { - glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session ID: %v", err) - h.writeSTSErrorResponse(w, r, STSErrInternalError, err) - return + return STSCredentials{}, nil, fmt.Errorf("failed to generate session ID: %w", err) } expiration := time.Now().Add(duration) @@ -469,49 +443,48 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r roleName = roleArn // Fallback to full ARN if extraction fails } - // Create session claims with role and LDAP provider information + // Create session claims with role information claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). WithSessionName(roleSessionName). - WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), ldapUserIdentity.PrincipalArn). - WithIdentityProvider("ldap", identity.UserID, identity.Provider) + WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), principalArn) + + // Apply custom claims if provided (e.g., LDAP identity) + if modifyClaims != nil { + modifyClaims(claims) + } // Generate JWT session token sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims) if err != nil { - glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate session token: %v", err) - h.writeSTSErrorResponse(w, r, STSErrInternalError, err) - return + return STSCredentials{}, nil, fmt.Errorf("failed to generate session token: %w", err) } // Generate temporary credentials from session ID (deterministic) credGen := sts.NewCredentialGenerator() creds, err := credGen.GenerateTemporaryCredentials(sessionId, expiration) if err != nil { - glog.Errorf("AssumeRoleWithLDAPIdentity: failed to generate credentials: %v", err) - h.writeSTSErrorResponse(w, r, STSErrInternalError, err) - return + return STSCredentials{}, nil, fmt.Errorf("failed to generate credentials: %w", err) } - // Build and return response with proper ARN formatting - // accountId is already defined above (line 423-426) + // Get account ID from STS config or use default + accountId := "111122223333" // Default account ID + if h.stsService != nil && h.stsService.Config != nil && h.stsService.Config.AccountId != "" { + accountId = h.stsService.Config.AccountId + } - xmlResponse := &AssumeRoleWithLDAPIdentityResponse{ - Result: LDAPIdentityResult{ - Credentials: STSCredentials{ - AccessKeyId: creds.AccessKeyId, - SecretAccessKey: creds.SecretAccessKey, - SessionToken: sessionToken, - Expiration: expiration.Format(time.RFC3339), - }, - AssumedRoleUser: &AssumedRoleUser{ - AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName), - Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName), - }, - }, + stsCreds := STSCredentials{ + AccessKeyId: creds.AccessKeyId, + SecretAccessKey: creds.SecretAccessKey, + SessionToken: sessionToken, + Expiration: expiration.Format(time.RFC3339), } - xmlResponse.ResponseMetadata.RequestId = fmt.Sprintf("%d", time.Now().UnixNano()) - s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) + assumedUser := &AssumedRoleUser{ + AssumedRoleId: fmt.Sprintf("%s:%s", roleName, roleSessionName), + Arn: fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountId, roleName, roleSessionName), + } + + return stsCreds, assumedUser, nil } // STS Response types for XML marshaling