diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index d31f322b0..fe1cdaccb 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -186,14 +186,22 @@ func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*provide attributes["roles"] = strings.Join(roles, ",") } - return &providers.ExternalIdentity{ + identity := &providers.ExternalIdentity{ UserID: claims.Subject, Email: email, DisplayName: displayName, Groups: groups, Attributes: attributes, Provider: p.name, - }, nil + } + + // Pass the token expiration to limit session duration + // This ensures the STS session doesn't exceed the source token's validity + if !claims.ExpiresAt.IsZero() { + identity.TokenExpiration = &claims.ExpiresAt + } + + return identity, nil } // GetUserInfo retrieves user information from the UserInfo endpoint @@ -372,6 +380,24 @@ func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*provid Claims: make(map[string]interface{}), } + // Extract time-based claims (exp, iat, nbf) + for key, target := range map[string]*time.Time{ + "exp": &tokenClaims.ExpiresAt, + "iat": &tokenClaims.IssuedAt, + "nbf": &tokenClaims.NotBefore, + } { + if val, ok := claims[key]; ok { + switch v := val.(type) { + case float64: + *target = time.Unix(int64(v), 0) + case json.Number: + if intVal, err := v.Int64(); err == nil { + *target = time.Unix(intVal, 0) + } + } + } + } + // Copy all claims for key, value := range claims { tokenClaims.Claims[key] = value diff --git a/weed/iam/providers/provider.go b/weed/iam/providers/provider.go index 5c1deb03d..3b7affc8e 100644 --- a/weed/iam/providers/provider.go +++ b/weed/iam/providers/provider.go @@ -47,6 +47,10 @@ type ExternalIdentity struct { // Provider is the name of the identity provider Provider string `json:"provider"` + + // TokenExpiration is the expiration time of the source identity token + // This is used to limit session duration to not exceed the token's exp claim + TokenExpiration *time.Time `json:"tokenExpiration,omitempty"` } // Validate validates the external identity structure diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index 3d9f9af35..e28340f30 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -422,8 +422,9 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass return nil, fmt.Errorf("role assumption denied: %w", err) } - // 3. Calculate session duration - sessionDuration := s.calculateSessionDuration(request.DurationSeconds) + // 3. Calculate session duration, capping at the source token's expiration + // This ensures sessions from short-lived tokens (e.g., GitLab CI job tokens) don't outlive their source + sessionDuration := s.calculateSessionDuration(request.DurationSeconds, externalIdentity.TokenExpiration) expiresAt := time.Now().Add(sessionDuration) // 4. Generate session ID and credentials @@ -502,7 +503,8 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass } // 4. Calculate session duration - sessionDuration := s.calculateSessionDuration(request.DurationSeconds) + // For credential-based auth, there's no source token with expiration to cap against + sessionDuration := s.calculateSessionDuration(request.DurationSeconds, nil) expiresAt := time.Now().Add(sessionDuration) // 5. Generate session ID and temporary credentials @@ -745,14 +747,42 @@ func (s *STSService) validateRoleAssumptionForCredentials(ctx context.Context, r return nil } -// calculateSessionDuration calculates the session duration -func (s *STSService) calculateSessionDuration(durationSeconds *int64) time.Duration { +// calculateSessionDuration calculates the session duration, respecting the source token's expiration +// If the incoming web identity token has an exp claim, the session duration is capped to not exceed it +// This ensures that sessions from short-lived tokens (e.g., GitLab CI job tokens) don't outlive their source +func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpiration *time.Time) time.Duration { + var duration time.Duration if durationSeconds != nil { - return time.Duration(*durationSeconds) * time.Second + duration = time.Duration(*durationSeconds) * time.Second + } else { + // Use default from config + duration = s.Config.TokenDuration.Duration + } + + // If the source token has an expiration, cap the session duration to not exceed it + // This follows the principle: "if calculated exp > incoming exp claim, then limit outgoing exp to incoming exp" + if tokenExpiration != nil && !tokenExpiration.IsZero() { + timeUntilTokenExpiry := time.Until(*tokenExpiration) + if timeUntilTokenExpiry <= 0 { + // Token already expired - use minimal duration as defense-in-depth + // The token should have been rejected during validation, but we handle this defensively + glog.V(2).Infof("Source token already expired, using minimal session duration") + duration = time.Minute + } else if timeUntilTokenExpiry < duration { + glog.V(2).Infof("Limiting session duration from %v to %v based on source token expiration", + duration, timeUntilTokenExpiry) + duration = timeUntilTokenExpiry + } + } + + // Cap at MaxSessionLength if configured + if s.Config.MaxSessionLength.Duration > 0 && duration > s.Config.MaxSessionLength.Duration { + glog.V(2).Infof("Limiting session duration from %v to %v based on MaxSessionLength config", + duration, s.Config.MaxSessionLength.Duration) + duration = s.Config.MaxSessionLength.Duration } - // Use default from config - return s.Config.TokenDuration.Duration + return duration } // extractSessionIdFromToken extracts session ID from JWT session token diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go index 72d69c8c8..56b6755de 100644 --- a/weed/iam/sts/sts_service_test.go +++ b/weed/iam/sts/sts_service_test.go @@ -451,3 +451,212 @@ func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) } return nil, fmt.Errorf("invalid token") } + +// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim +func TestSessionDurationCappedByTokenExpiration(t *testing.T) { + service := NewSTSService() + + config := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour + MaxSessionLength: FlexibleDuration{time.Hour * 12}, + Issuer: "test-sts", + SigningKey: []byte("test-signing-key-32-characters-long"), + } + + err := service.Initialize(config) + require.NoError(t, err) + + tests := []struct { + name string + durationSeconds *int64 + tokenExpiration *time.Time + expectedMaxSeconds int64 + description string + }{ + { + name: "no token expiration - use default duration", + durationSeconds: nil, + tokenExpiration: nil, + expectedMaxSeconds: 3600, // 1 hour default + description: "When no token expiration is set, use the configured default duration", + }, + { + name: "token expires before default duration", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)), + expectedMaxSeconds: 30 * 60, // 30 minutes + description: "When token expires in 30 min, session should be capped at 30 min", + }, + { + name: "token expires after default duration - use default", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)), + expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry + description: "When token expires after default duration, use the default duration", + }, + { + name: "requested duration shorter than token expiry", + durationSeconds: int64Ptr(1800), // 30 min requested + tokenExpiration: timePtr(time.Now().Add(time.Hour)), + expectedMaxSeconds: 1800, // 30 minutes as requested + description: "When requested duration is shorter than token expiry, use requested duration", + }, + { + name: "requested duration longer than token expiry - cap at token expiry", + durationSeconds: int64Ptr(3600), // 1 hour requested + tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)), + expectedMaxSeconds: 15 * 60, // Capped at 15 minutes + description: "When requested duration exceeds token expiry, cap at token expiry", + }, + { + name: "GitLab CI short-lived token scenario", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)), + expectedMaxSeconds: 5 * 60, // 5 minutes + description: "GitLab CI job with 5 minute timeout should result in 5 minute session", + }, + { + name: "already expired token - defense in depth", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago + expectedMaxSeconds: 60, // 1 minute minimum + description: "Already expired token should result in minimal 1 minute session", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration) + + // Allow 5 second tolerance for time calculations + maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second + minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second + + assert.GreaterOrEqual(t, duration, minExpected, + "%s: duration %v should be >= %v", tt.description, duration, minExpected) + assert.LessOrEqual(t, duration, maxExpected, + "%s: duration %v should be <= %v", tt.description, duration, maxExpected) + }) + } +} + +// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped +func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) { + service := NewSTSService() + + config := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{time.Hour * 12}, + Issuer: "test-sts", + SigningKey: []byte("test-signing-key-32-characters-long"), + } + + err := service.Initialize(config) + require.NoError(t, err) + + // Set up mock trust policy validator + mockValidator := &MockTrustPolicyValidator{} + service.SetTrustPolicyValidator(mockValidator) + + // Create a mock provider that returns tokens with short expiration + shortLivedTokenExpiration := time.Now().Add(10 * time.Minute) + mockProvider := &MockIdentityProviderWithExpiration{ + name: "short-lived-issuer", + tokenExpiration: &shortLivedTokenExpiration, + } + service.RegisterProvider(mockProvider) + + ctx := context.Background() + + // Create a JWT token with short expiration + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iss": "short-lived-issuer", + "sub": "test-user", + "aud": "test-client", + "exp": shortLivedTokenExpiration.Unix(), + "iat": time.Now().Unix(), + }) + tokenString, err := token.SignedString([]byte("test-signing-key")) + require.NoError(t, err) + + request := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:aws:iam::role/TestRole", + WebIdentityToken: tokenString, + RoleSessionName: "test-session", + } + + response, err := service.AssumeRoleWithWebIdentity(ctx, request) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify the session expires at or before the token expiration + // Allow 5 second tolerance + assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)), + "Session expiration (%v) should not exceed token expiration (%v)", + response.Credentials.Expiration, shortLivedTokenExpiration) +} + +// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration +type MockIdentityProviderWithExpiration struct { + name string + tokenExpiration *time.Time +} + +func (m *MockIdentityProviderWithExpiration) Name() string { + return m.name +} + +func (m *MockIdentityProviderWithExpiration) GetIssuer() string { + return m.name +} + +func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error { + return nil +} + +func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { + // Parse the token to get subject + parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + claims, ok := parsedToken.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid claims") + } + + subject, _ := claims["sub"].(string) + + identity := &providers.ExternalIdentity{ + UserID: subject, + Email: subject + "@example.com", + DisplayName: "Test User", + Provider: m.name, + TokenExpiration: m.tokenExpiration, + } + + return identity, nil +} + +func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { + return &providers.ExternalIdentity{ + UserID: userID, + Provider: m.name, + }, nil +} + +func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { + claims := &providers.TokenClaims{ + Subject: "test-user", + Issuer: m.name, + } + if m.tokenExpiration != nil { + claims.ExpiresAt = *m.tokenExpiration + } + return claims, nil +} + +func timePtr(t time.Time) *time.Time { + return &t +}