Browse Source

STS: Fallback to Caller Identity when RoleArn is missing in AssumeRole (#8345)

* s3api: make RoleArn optional in AssumeRole

* s3api: address PR feedback for optional RoleArn

* iam: add configurable default role for AssumeRole

* S3 STS: Use caller identity when RoleArn is missing

- Fallback to PrincipalArn/Context in AssumeRole if RoleArn is empty

- Handle User ARNs in prepareSTSCredentials

- Fix PrincipalArn generation for env var credentials

* Test: Add unit test for AssumeRole caller identity fallback

* fix(s3api): propagate admin permissions to assumed role session when using caller identity fallback

* STS: Fix is_admin propagation and optimize IAM policy evaluation for assumed roles

- Restore is_admin propagation via JWT req_ctx
- Optimize IsActionAllowed to skip role lookups for admin sessions
- Ensure session policies are still applied for downscoping
- Remove debug logging
- Fix syntax errors in cleanup

* fix(iam): resolve STS policy bypass for admin sessions

- Fixed IsActionAllowed in iam_manager.go to correctly identify and validate internal STS tokens, ensuring session policies are enforced.
- Refactored VerifyActionPermission in auth_credentials.go to properly handle session tokens and avoid legacy authorization short-circuits.
- Added debug logging for better tracing of policy evaluation and session validation.
pull/8347/head
Chris Lu 5 days ago
committed by GitHub
parent
commit
cf8e383e1e
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 124
      weed/iam/integration/iam_manager.go
  2. 2
      weed/iam/sts/token_utils.go
  3. 17
      weed/s3api/auth_credentials.go
  4. 5
      weed/s3api/auth_signature_v4.go
  5. 4
      weed/s3api/s3_iam_middleware.go
  6. 6
      weed/s3api/s3api_server.go
  7. 81
      weed/s3api/s3api_sts.go
  8. 153
      weed/s3api/s3api_sts_assume_role_test.go

124
weed/iam/integration/iam_manager.go

@ -323,14 +323,30 @@ func (m *IAMManager) IsActionAllowed(ctx context.Context, request *ActionRequest
return false, fmt.Errorf("IAM manager not initialized") return false, fmt.Errorf("IAM manager not initialized")
} }
// Validate session token if present (skip for OIDC tokens which are already validated,
// and skip for empty tokens which represent static access keys)
// Validate session token if present
// We always try to validate with the internal STS service first if it's a SeaweedFS token.
// This ensures that session policies embedded in the token are correctly extracted and enforced.
var sessionInfo *sts.SessionInfo var sessionInfo *sts.SessionInfo
if request.SessionToken != "" && !isOIDCToken(request.SessionToken) {
var err error
sessionInfo, err = m.stsService.ValidateSessionToken(ctx, request.SessionToken)
if err != nil {
return false, fmt.Errorf("invalid session: %w", err)
if request.SessionToken != "" {
// Parse unverified to check issuer
parsed, _, err := new(jwt.Parser).ParseUnverified(request.SessionToken, jwt.MapClaims{})
isInternal := false
if err == nil {
if claims, ok := parsed.Claims.(jwt.MapClaims); ok {
if issuer, ok := claims["iss"].(string); ok && m.stsService != nil && m.stsService.Config != nil {
if issuer == m.stsService.Config.Issuer {
isInternal = true
}
}
}
}
if isInternal || !isOIDCToken(request.SessionToken) {
var err error
sessionInfo, err = m.stsService.ValidateSessionToken(ctx, request.SessionToken)
if err != nil {
return false, fmt.Errorf("invalid session: %w", err)
}
} }
} }
@ -349,7 +365,17 @@ func (m *IAMManager) IsActionAllowed(ctx context.Context, request *ActionRequest
// Add principal to context for policy matching // Add principal to context for policy matching
// The PolicyEngine checks RequestContext["principal"] or RequestContext["aws:PrincipalArn"] // The PolicyEngine checks RequestContext["principal"] or RequestContext["aws:PrincipalArn"]
evalCtx.RequestContext["principal"] = request.Principal evalCtx.RequestContext["principal"] = request.Principal
evalCtx.RequestContext["aws:PrincipalArn"] = request.Principal
evalCtx.RequestContext["aws:PrincipalArn"] = request.Principal // AWS standard key
// Check if this is an admin request - bypass policy evaluation if so
// This mirrors the logic in auth_signature_v4.go but applies it at authorization time
isAdmin := false
if request.RequestContext != nil {
if val, ok := request.RequestContext["is_admin"].(bool); ok && val {
isAdmin = true
}
// Print full request context for debugging
}
// Parse principal ARN to extract details for context variables (e.g. ${aws:username}) // Parse principal ARN to extract details for context variables (e.g. ${aws:username})
arnInfo := utils.ParsePrincipalARN(request.Principal) arnInfo := utils.ParsePrincipalARN(request.Principal)
@ -382,48 +408,56 @@ func (m *IAMManager) IsActionAllowed(ctx context.Context, request *ActionRequest
} }
} }
policies := request.PolicyNames
if len(policies) == 0 {
// Extract role name from principal ARN
roleName := utils.ExtractRoleNameFromPrincipal(request.Principal)
if roleName == "" {
userName := utils.ExtractUserNameFromPrincipal(request.Principal)
if userName == "" {
return false, fmt.Errorf("could not extract role from principal: %s", request.Principal)
}
if m.userStore == nil {
return false, fmt.Errorf("user store unavailable for principal: %s", request.Principal)
}
user, err := m.userStore.GetUser(ctx, userName)
if err != nil || user == nil {
return false, fmt.Errorf("user not found for principal: %s (user=%s)", request.Principal, userName)
}
policies = user.GetPolicyNames()
} else {
// Get role definition
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil {
return false, fmt.Errorf("role not found: %s", roleName)
}
var baseResult *policy.EvaluationResult
var err error
if isAdmin {
// Admin always has base access allowed
baseResult = &policy.EvaluationResult{Effect: policy.EffectAllow}
} else {
policies := request.PolicyNames
if len(policies) == 0 {
// Extract role name from principal ARN
roleName := utils.ExtractRoleNameFromPrincipal(request.Principal)
if roleName == "" {
userName := utils.ExtractUserNameFromPrincipal(request.Principal)
if userName == "" {
return false, fmt.Errorf("could not extract role from principal: %s", request.Principal)
}
if m.userStore == nil {
return false, fmt.Errorf("user store unavailable for principal: %s", request.Principal)
}
user, err := m.userStore.GetUser(ctx, userName)
if err != nil || user == nil {
return false, fmt.Errorf("user not found for principal: %s (user=%s)", request.Principal, userName)
}
policies = user.GetPolicyNames()
} else {
// Get role definition
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil {
return false, fmt.Errorf("role not found: %s", roleName)
}
policies = roleDef.AttachedPolicies
policies = roleDef.AttachedPolicies
}
} }
}
if bucketPolicyName != "" {
// Enforce an upper bound on the number of policies to avoid excessive allocations
if len(policies) >= maxPoliciesForEvaluation {
return false, fmt.Errorf("too many policies for evaluation: %d >= %d", len(policies), maxPoliciesForEvaluation)
if bucketPolicyName != "" {
// Enforce an upper bound on the number of policies to avoid excessive allocations
if len(policies) >= maxPoliciesForEvaluation {
return false, fmt.Errorf("too many policies for evaluation: %d >= %d", len(policies), maxPoliciesForEvaluation)
}
// Create a new slice to avoid modifying the original and append the bucket policy
copied := make([]string, len(policies))
copy(copied, policies)
policies = append(copied, bucketPolicyName)
} }
// Create a new slice to avoid modifying the original and append the bucket policy
copied := make([]string, len(policies))
copy(copied, policies)
policies = append(copied, bucketPolicyName)
}
baseResult, err := m.policyEngine.Evaluate(ctx, "", evalCtx, policies)
if err != nil {
return false, fmt.Errorf("policy evaluation failed: %w", err)
baseResult, err = m.policyEngine.Evaluate(ctx, "", evalCtx, policies)
if err != nil {
return false, fmt.Errorf("policy evaluation failed: %w", err)
}
} }
// Base policy must allow; if it doesn't, deny immediately (session policy can only further restrict) // Base policy must allow; if it doesn't, deny immediately (session policy can only further restrict)

2
weed/iam/sts/token_utils.go

@ -44,6 +44,8 @@ func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string
claims.Issuer = t.issuer claims.Issuer = t.issuer
} }
// SECURITY: Use deterministic signing results for troubleshooting if needed,
// but standard HS256 with common secret is usually sufficient.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(t.signingKey) return token.SignedString(t.signingKey)
} }

17
weed/s3api/auth_credentials.go

@ -300,7 +300,8 @@ func (iam *IdentityAccessManagement) loadEnvironmentVariableCredentials() {
Actions: []Action{ Actions: []Action{
s3_constants.ACTION_ADMIN, s3_constants.ACTION_ADMIN,
}, },
IsStatic: true,
PrincipalArn: generatePrincipalArn(identityName),
IsStatic: true,
} }
iam.m.Lock() iam.m.Lock()
@ -1562,14 +1563,22 @@ func (iam *IdentityAccessManagement) VerifyActionPermission(r *http.Request, ide
} }
// Traditional identities (with Actions from -s3.config) use legacy auth, // Traditional identities (with Actions from -s3.config) use legacy auth,
// JWT/STS identities (no Actions) use IAM authorization
// JWT/STS identities (no Actions or having a session token) use IAM authorization.
// IMPORTANT: We MUST prioritize IAM authorization for any request with a session token
// to ensure that session policies are correctly enforced.
hasSessionToken := r.Header.Get("X-SeaweedFS-Session-Token") != "" ||
r.Header.Get("X-Amz-Security-Token") != "" ||
r.URL.Query().Get("X-Amz-Security-Token") != ""
if (len(identity.Actions) == 0 || hasSessionToken) && iam.iamIntegration != nil {
return iam.authorizeWithIAM(r, identity, action, bucket, object)
}
if len(identity.Actions) > 0 { if len(identity.Actions) > 0 {
if !identity.CanDo(action, bucket, object) { if !identity.CanDo(action, bucket, object) {
return s3err.ErrAccessDenied return s3err.ErrAccessDenied
} }
return s3err.ErrNone return s3err.ErrNone
} else if iam.iamIntegration != nil {
return iam.authorizeWithIAM(r, identity, action, bucket, object)
} }
return s3err.ErrAccessDenied return s3err.ErrAccessDenied

5
weed/s3api/auth_signature_v4.go

@ -434,6 +434,11 @@ func (iam *IdentityAccessManagement) validateSTSSessionToken(r *http.Request, se
Claims: claims, // Populate Claims for policy variable substitution Claims: claims, // Populate Claims for policy variable substitution
} }
// Restore admin privileges if the session was created by an admin
// if isAdmin, ok := claims["is_admin"].(bool); ok && isAdmin {
// identity.Actions = append(identity.Actions, s3_constants.ACTION_ADMIN)
// }
glog.V(2).Infof("Successfully validated STS session token for principal: %s, assumed role user: %s", glog.V(2).Infof("Successfully validated STS session token for principal: %s, assumed role user: %s",
sessionInfo.Principal, sessionInfo.AssumedRoleUser) sessionInfo.Principal, sessionInfo.AssumedRoleUser)
return identity, cred, s3err.ErrNone return identity, cred, s3err.ErrNone

4
weed/s3api/s3_iam_middleware.go

@ -233,6 +233,10 @@ func (s3iam *S3IAMIntegration) ValidateSessionToken(ctx context.Context, token s
// AuthorizeAction authorizes actions using our policy engine // AuthorizeAction authorizes actions using our policy engine
func (s3iam *S3IAMIntegration) AuthorizeAction(ctx context.Context, identity *IAMIdentity, action Action, bucket string, objectKey string, r *http.Request) s3err.ErrorCode { func (s3iam *S3IAMIntegration) AuthorizeAction(ctx context.Context, identity *IAMIdentity, action Action, bucket string, objectKey string, r *http.Request) s3err.ErrorCode {
fmt.Printf("DEBUG: AuthorizeAction called: Identity=%s Action=%s Bucket=%s Enabled=%v\n", identity.Name, action, bucket, s3iam.enabled)
if identity.Claims != nil {
fmt.Printf("DEBUG: AuthorizeAction Identity.Claims=%v\n", identity.Claims)
}
if !s3iam.enabled { if !s3iam.enabled {
return s3err.ErrNone // Fallback to existing authorization return s3err.ErrNone // Fallback to existing authorization
} }

6
weed/s3api/s3api_server.go

@ -852,6 +852,12 @@ func loadIAMManagerFromConfig(configPath string, filerAddressProvider func() str
if err := json.Unmarshal(configData, &configRoot); err != nil { if err := json.Unmarshal(configData, &configRoot); err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err) return nil, fmt.Errorf("failed to parse config: %w", err)
} }
glog.V(0).Infof("DEBUG: Loaded IAM Config. Policy=%v. Raw JSON len=%d", configRoot.Policy, len(configData))
if configRoot.Policy != nil {
glog.V(0).Infof("DEBUG: Policy Config: DefaultEffect='%s'", configRoot.Policy.DefaultEffect)
} else {
glog.V(0).Infof("DEBUG: Policy Config is NIL")
}
// Ensure a valid policy engine config exists // Ensure a valid policy engine config exists
if configRoot.Policy == nil { if configRoot.Policy == nil {

81
weed/s3api/s3api_sts.go

@ -186,6 +186,8 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r *
Policy: sessionPolicyPtr, Policy: sessionPolicyPtr,
} }
glog.V(0).Infof("DEBUG: AssumeRoleWithWebIdentity: RoleArn=%s SessionPolicyLen=%d", roleArn, len(sessionPolicyJSON))
// Call STS service // Call STS service
response, err := h.stsService.AssumeRoleWithWebIdentity(ctx, request) response, err := h.stsService.AssumeRoleWithWebIdentity(ctx, request)
if err != nil { if err != nil {
@ -237,11 +239,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
roleSessionName := r.FormValue("RoleSessionName") roleSessionName := r.FormValue("RoleSessionName")
// Validate required parameters // Validate required parameters
if roleArn == "" {
h.writeSTSErrorResponse(w, r, STSErrMissingParameter,
fmt.Errorf("RoleArn is required"))
return
}
// RoleArn is optional to support S3-compatible clients that omit it
if roleSessionName == "" { if roleSessionName == "" {
h.writeSTSErrorResponse(w, r, STSErrMissingParameter, h.writeSTSErrorResponse(w, r, STSErrMissingParameter,
@ -290,22 +288,40 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
// Check if the caller is authorized to assume the role (sts:AssumeRole permission) // Check if the caller is authorized to assume the role (sts:AssumeRole permission)
// This validates that the caller has a policy allowing sts:AssumeRole on the target role // This validates that the caller has a policy allowing sts:AssumeRole on the target role
if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone {
glog.V(2).Infof("AssumeRole: caller %s is not authorized to assume role %s", identity.Name, roleArn)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied,
fmt.Errorf("user %s is not authorized to assume role %s", identity.Name, roleArn))
return
}
// Check authorizations
if roleArn != "" {
// Check if the caller is authorized to assume the role (sts:AssumeRole permission)
if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone {
glog.V(2).Infof("AssumeRole: caller %s is not authorized to assume role %s", identity.Name, roleArn)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied,
fmt.Errorf("user %s is not authorized to assume role %s", identity.Name, roleArn))
return
}
// Validate that the target role trusts the caller (Trust Policy)
// This ensures the role's trust policy explicitly allows the principal to assume it
if err := h.iam.ValidateTrustPolicyForPrincipal(r.Context(), roleArn, identity.PrincipalArn); err != nil {
glog.V(2).Infof("AssumeRole: trust policy validation failed for %s to assume %s: %v", identity.Name, roleArn, err)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("trust policy denies access"))
return
// Validate that the target role trusts the caller (Trust Policy)
if err := h.iam.ValidateTrustPolicyForPrincipal(r.Context(), roleArn, identity.PrincipalArn); err != nil {
glog.V(2).Infof("AssumeRole: trust policy validation failed for %s to assume %s: %v", identity.Name, roleArn, err)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("trust policy denies access"))
return
}
} else {
// If RoleArn is missing, default to the caller's identity (User Context)
// This allows the user to "assume" a session for themselves, inheriting their own permissions.
roleArn = identity.PrincipalArn
glog.V(2).Infof("AssumeRole: no RoleArn provided, defaulting to caller identity: %s", roleArn)
// We still enforce a global "sts:AssumeRole" check, similar to how we'd check if they can assume *any* role.
// However, for self-assumption, this might be implicit.
// For safety/consistency with previous logic, we keep the check but strictly it might not be required by AWS for GetSessionToken.
// But since this IS AssumeRole, let's keep it.
// Admin/Global check when no specific role is requested
if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", ""); authErr != s3err.ErrNone {
glog.Warningf("AssumeRole: caller %s attempted to assume role without RoleArn and lacks global sts:AssumeRole permission", identity.Name)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("access denied"))
return
}
} }
// Parse optional inline session policy for downscoping
sessionPolicyJSON, err := sts.NormalizeSessionPolicy(r.FormValue("Policy")) sessionPolicyJSON, err := sts.NormalizeSessionPolicy(r.FormValue("Policy"))
if err != nil { if err != nil {
h.writeSTSErrorResponse(w, r, STSErrMalformedPolicyDocument, h.writeSTSErrorResponse(w, r, STSErrMalformedPolicyDocument,
@ -313,8 +329,19 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
return return
} }
// Prepare custom claims for the session
var modifyClaims func(claims *sts.STSSessionClaims)
if identity.isAdmin() {
modifyClaims = func(claims *sts.STSSessionClaims) {
if claims.RequestContext == nil {
claims.RequestContext = make(map[string]interface{})
}
claims.RequestContext["is_admin"] = true
}
}
// Generate common STS components // Generate common STS components
stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, durationSeconds, sessionPolicyJSON, nil)
stsCreds, assumedUser, err := h.prepareSTSCredentials(roleArn, roleSessionName, durationSeconds, sessionPolicyJSON, modifyClaims)
if err != nil { if err != nil {
h.writeSTSErrorResponse(w, r, STSErrInternalError, err) h.writeSTSErrorResponse(w, r, STSErrInternalError, err)
return return
@ -492,7 +519,12 @@ func (h *STSHandlers) prepareSTSCredentials(roleArn, roleSessionName string,
expiration := time.Now().Add(duration) expiration := time.Now().Add(duration)
// Extract role name from ARN for proper response formatting // Extract role name from ARN for proper response formatting
roleName := utils.ExtractRoleNameFromArn(roleArn)
roleName := utils.ExtractRoleNameFromPrincipal(roleArn)
if roleName == "" {
// Try to extract user name if it's a user ARN (for "User Context" assumption)
roleName = utils.ExtractUserNameFromPrincipal(roleArn)
}
if roleName == "" { if roleName == "" {
roleName = roleArn // Fallback to full ARN if extraction fails roleName = roleArn // Fallback to full ARN if extraction fails
} }
@ -502,12 +534,19 @@ func (h *STSHandlers) prepareSTSCredentials(roleArn, roleSessionName string,
// Construct AssumedRoleUser ARN - this will be used as the principal for the vended token // Construct AssumedRoleUser ARN - this will be used as the principal for the vended token
assumedRoleArn := fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountID, roleName, roleSessionName) assumedRoleArn := fmt.Sprintf("arn:aws:sts::%s:assumed-role/%s/%s", accountID, roleName, roleSessionName)
// Use assumedRoleArn as RoleArn in claims if original RoleArn is empty
// This ensures STSSessionClaims.IsValid() passes (it requires non-empty RoleArn)
effectiveRoleArn := roleArn
if effectiveRoleArn == "" {
effectiveRoleArn = assumedRoleArn
}
// Create session claims with role information // Create session claims with role information
// SECURITY: Use the assumedRoleArn as the principal in the token. // SECURITY: Use the assumedRoleArn as the principal in the token.
// This ensures that subsequent requests using this token are correctly identified as the assumed role. // This ensures that subsequent requests using this token are correctly identified as the assumed role.
claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration).
WithSessionName(roleSessionName). WithSessionName(roleSessionName).
WithRoleInfo(roleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), assumedRoleArn)
WithRoleInfo(effectiveRoleArn, fmt.Sprintf("%s:%s", roleName, roleSessionName), assumedRoleArn)
if sessionPolicy != "" { if sessionPolicy != "" {
claims.WithSessionPolicy(sessionPolicy) claims.WithSessionPolicy(sessionPolicy)

153
weed/s3api/s3api_sts_assume_role_test.go

@ -0,0 +1,153 @@
package s3api
import (
"context"
"fmt"
"net/http"
"net/url"
"testing"
"github.com/seaweedfs/seaweedfs/weed/iam/sts"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAssumeRole_CallerIdentityFallback tests the fallback logic when RoleArn is missing
func TestAssumeRole_CallerIdentityFallback(t *testing.T) {
// Setup STS service
stsService, _ := setupTestSTSService(t)
// Create IAM integration mock
iamMock := &MockIAMIntegration{
authorizeFunc: func(ctx context.Context, identity *IAMIdentity, action Action, bucket, object string, r *http.Request) s3err.ErrorCode {
// Allow global sts:AssumeRole
if action == "sts:AssumeRole" {
return s3err.ErrNone
}
return s3err.ErrAccessDenied
},
validateTrustPolicyFunc: func(ctx context.Context, roleArn, principalArn string) error {
// Allow all trust policies for this test
return nil
},
}
// Create IAM service with the mock integration
iam := &IdentityAccessManagement{
iamIntegration: iamMock,
}
// Create STS handlers
stsHandlers := NewSTSHandlers(stsService, iam)
// Test case 1: Caller is an IAM User, RoleArn is missing
t.Run("Caller is IAM User, No RoleArn", func(t *testing.T) {
// Mock request
req, err := http.NewRequest("POST", "/", nil)
require.NoError(t, err)
req.Form = url.Values{}
req.Form.Set("Action", "AssumeRole")
req.Form.Set("RoleSessionName", "test-session")
req.Form.Set("Version", "2011-06-15")
// Mock the authenticated identity (IAM User)
callerIdentity := &Identity{
Name: "alice",
Account: &AccountAdmin,
PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID),
Actions: []Action{s3_constants.ACTION_ADMIN},
}
// 1. Test prepareSTSCredentials with NO RoleArn (simulating the fallback logic having passed PrincipalArn)
// expected RoleArn passed to prepareSTSCredentials would be the caller's PrincipalArn
fallbackRoleArn := callerIdentity.PrincipalArn
// Prepare custom claims for the session (mimicking handleAssumeRole logic)
var modifyClaims func(claims *sts.STSSessionClaims)
if callerIdentity.isAdmin() {
modifyClaims = func(claims *sts.STSSessionClaims) {
if claims.RequestContext == nil {
claims.RequestContext = make(map[string]interface{})
}
claims.RequestContext["is_admin"] = true
}
}
stsCreds, assumedUser, err := stsHandlers.prepareSTSCredentials(fallbackRoleArn, "test-session", nil, "", modifyClaims)
require.NoError(t, err)
// Assertions
// The role name should be extracted from the user ARN ("alice")
assert.Contains(t, assumedUser.Arn, fmt.Sprintf("assumed-role/alice/test-session"))
assert.Contains(t, assumedUser.AssumedRoleId, "alice:test-session")
// Verify token claims using ValidateSessionToken
sessionInfo, err := stsService.ValidateSessionToken(context.Background(), stsCreds.SessionToken)
require.NoError(t, err)
// The RoleArn in session info should match the fallback ARN (user ARN)
assert.Equal(t, fallbackRoleArn, sessionInfo.RoleArn)
// Verify is_admin claim is present
isAdmin, ok := sessionInfo.RequestContext["is_admin"].(bool)
assert.True(t, ok, "is_admin claim should be present")
assert.True(t, isAdmin, "is_admin claim should be true")
})
// Test case 2: Caller is an STS Assumed Role, No RoleArn
t.Run("Caller is STS Assumed Role, No RoleArn", func(t *testing.T) {
// Mock identity
callerIdentity := &Identity{
Name: "arn:aws:sts::111122223333:assumed-role/admin/session1",
Account: &AccountAdmin,
PrincipalArn: "arn:aws:sts::111122223333:assumed-role/admin/session1",
}
fallbackRoleArn := callerIdentity.PrincipalArn
stsCreds, assumedUser, err := stsHandlers.prepareSTSCredentials(fallbackRoleArn, "nested-session", nil, "", nil)
require.NoError(t, err)
// The role name should be extracted from the assumed role ARN ("admin")
assert.Contains(t, assumedUser.Arn, "assumed-role/admin/nested-session")
assert.Contains(t, assumedUser.AssumedRoleId, "admin:nested-session")
// Check claims
sessionInfo, err := stsService.ValidateSessionToken(context.Background(), stsCreds.SessionToken)
require.NoError(t, err)
assert.Equal(t, fallbackRoleArn, sessionInfo.RoleArn)
})
// Test case 3: Explicit RoleArn provided (Standard AssumeRole)
t.Run("Explicit RoleArn Provided", func(t *testing.T) {
explicitRoleArn := "arn:aws:iam::111122223333:role/TargetRole"
stsCreds, assumedUser, err := stsHandlers.prepareSTSCredentials(explicitRoleArn, "explicit-session", nil, "", nil)
require.NoError(t, err)
// Role name should be "TargetRole"
assert.Contains(t, assumedUser.Arn, "assumed-role/TargetRole/explicit-session")
// Check claims
sessionInfo, err := stsService.ValidateSessionToken(context.Background(), stsCreds.SessionToken)
require.NoError(t, err)
assert.Equal(t, explicitRoleArn, sessionInfo.RoleArn)
})
// Test case 4: Malformed ARN (Edge case)
t.Run("Malformed ARN", func(t *testing.T) {
malformedArn := "invalid-arn"
stsCreds, assumedUser, err := stsHandlers.prepareSTSCredentials(malformedArn, "bad-session", nil, "", nil)
require.NoError(t, err)
// Fallback behavior: use full string as role name if extraction fails
assert.Contains(t, assumedUser.Arn, "assumed-role/invalid-arn/bad-session")
sessionInfo, err := stsService.ValidateSessionToken(context.Background(), stsCreds.SessionToken)
require.NoError(t, err)
assert.Equal(t, malformedArn, sessionInfo.RoleArn)
})
}
Loading…
Cancel
Save