Browse Source

fix(sts): encapsulate TokenGenerator in STSService and add getter

pull/8003/head
Chris Lu 8 hours ago
parent
commit
7b79e1d7c6
  1. 24
      weed/iam/sts/cross_instance_token_test.go
  2. 24
      weed/iam/sts/distributed_sts_test.go
  3. 24
      weed/iam/sts/sts_service.go
  4. 2
      weed/s3api/s3api_sts.go

24
weed/iam/sts/cross_instance_token_test.go

@ -127,16 +127,16 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
sessionId := TestSessionID
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err, "Instance A should generate token")
// Validate token on Instance B
claimsFromB, err := instanceB.TokenGenerator.ValidateSessionToken(tokenFromA)
claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
require.NoError(t, err, "Instance B should validate token from Instance A")
assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
// Validate same token on Instance C
claimsFromC, err := instanceC.TokenGenerator.ValidateSessionToken(tokenFromA)
claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
require.NoError(t, err, "Instance C should validate token from Instance A")
assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
@ -295,15 +295,15 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) {
// Generate token on Instance A
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance A should validate its own token
_, err = instanceA.TokenGenerator.ValidateSessionToken(tokenFromA)
_, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.NoError(t, err, "Instance A should validate own token")
// Instance B should REJECT token due to different signing key
_, err = instanceB.TokenGenerator.ValidateSessionToken(tokenFromA)
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.Error(t, err, "Instance B should reject token with different signing key")
assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
})
@ -339,11 +339,11 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) {
// Generate token on Instance A
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance B should REJECT token due to different issuer
_, err = instanceB.TokenGenerator.ValidateSessionToken(tokenFromA)
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.Error(t, err, "Instance B should reject token with different issuer")
assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
})
@ -368,12 +368,12 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) {
// Generate token on Instance 0
sessionId := "multi-instance-test"
expiresAt := time.Now().Add(time.Hour)
token, err := instances[0].TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// All other instances should validate the token
for i := 1; i < 5; i++ {
claims, err := instances[i].TokenGenerator.ValidateSessionToken(token)
claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
require.NoError(t, err, "Instance %d should validate token", i)
assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
}
@ -486,10 +486,10 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
// Step 5: Token should be identical when parsed
claims2, err := gateway2.TokenGenerator.ValidateSessionToken(sessionToken)
claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
require.NoError(t, err)
claims3, err := gateway3.TokenGenerator.ValidateSessionToken(sessionToken)
claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
require.NoError(t, err)
assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")

24
weed/iam/sts/distributed_sts_test.go

@ -109,9 +109,9 @@ func TestDistributedSTSService(t *testing.T) {
expiresAt := time.Now().Add(time.Hour)
// Generate tokens from different instances
token1, err1 := instance1.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token2, err2 := instance2.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token3, err3 := instance3.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err1, "Instance 1 token generation should succeed")
require.NoError(t, err2, "Instance 2 token generation should succeed")
@ -130,13 +130,13 @@ func TestDistributedSTSService(t *testing.T) {
expiresAt := time.Now().Add(time.Hour)
// Generate token on instance 1
token, err := instance1.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Validate on all instances
claims1, err1 := instance1.TokenGenerator.ValidateSessionToken(token)
claims2, err2 := instance2.TokenGenerator.ValidateSessionToken(token)
claims3, err3 := instance3.TokenGenerator.ValidateSessionToken(token)
claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
require.NoError(t, err1, "Instance 1 should validate token from instance 1")
require.NoError(t, err2, "Instance 2 should validate token from instance 1")
@ -216,15 +216,15 @@ func TestSTSConfigurationValidation(t *testing.T) {
// Generate token on instance 1
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
token, err := instance1.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance 1 should validate its own token
_, err = instance1.TokenGenerator.ValidateSessionToken(token)
_, err = instance1.GetTokenGenerator().ValidateSessionToken(token)
assert.NoError(t, err, "Instance 1 should validate its own token")
// Instance 2 should reject token from instance 1 (different signing key)
_, err = instance2.TokenGenerator.ValidateSessionToken(token)
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
assert.Error(t, err, "Instance 2 should reject token with different signing key")
})
@ -258,12 +258,12 @@ func TestSTSConfigurationValidation(t *testing.T) {
// Generate token on instance 1
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
token, err := instance1.TokenGenerator.GenerateSessionToken(sessionId, expiresAt)
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance 2 should reject token due to issuer mismatch
// (Even though signing key is the same, issuer validation will fail)
_, err = instance2.TokenGenerator.ValidateSessionToken(token)
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
assert.Error(t, err, "Instance 2 should reject token with different issuer")
})
}

24
weed/iam/sts/sts_service.go

@ -77,10 +77,16 @@ type STSService struct {
initialized bool
providers map[string]providers.IdentityProvider
issuerToProvider map[string]providers.IdentityProvider // Efficient issuer-based provider lookup
TokenGenerator *TokenGenerator
tokenGenerator *TokenGenerator
trustPolicyValidator TrustPolicyValidator // Interface for trust policy validation
}
// GetTokenGenerator returns the token generator used by the STS service.
// This keeps the underlying field unexported while still allowing read-only access.
func (s *STSService) GetTokenGenerator() *TokenGenerator {
return s.tokenGenerator
}
// STSConfig holds STS service configuration
type STSConfig struct {
// TokenDuration is the default duration for issued tokens
@ -265,7 +271,7 @@ func (s *STSService) Initialize(config *STSConfig) error {
s.Config = config
// Initialize token generator for stateless JWT operations
s.TokenGenerator = NewTokenGenerator(config.SigningKey, config.Issuer)
s.tokenGenerator = NewTokenGenerator(config.SigningKey, config.Issuer)
// Load identity providers from configuration
if err := s.loadProvidersFromConfig(config); err != nil {
@ -460,7 +466,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
WithMaxDuration(sessionDuration)
// Generate self-contained JWT token with all session information
jwtToken, err := s.TokenGenerator.GenerateJWTWithClaims(sessionClaims)
jwtToken, err := s.tokenGenerator.GenerateJWTWithClaims(sessionClaims)
if err != nil {
return nil, fmt.Errorf("failed to generate JWT session token: %w", err)
}
@ -540,7 +546,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
WithMaxDuration(sessionDuration)
// Generate self-contained JWT token with all session information
jwtToken, err := s.TokenGenerator.GenerateJWTWithClaims(sessionClaims)
jwtToken, err := s.tokenGenerator.GenerateJWTWithClaims(sessionClaims)
if err != nil {
return nil, fmt.Errorf("failed to generate JWT session token: %w", err)
}
@ -566,7 +572,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
}
// Validate JWT and extract comprehensive session claims
claims, err := s.TokenGenerator.ValidateJWTWithClaims(sessionToken)
claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
if err != nil {
return nil, fmt.Errorf(ErrSessionValidationFailed, err)
}
@ -811,8 +817,8 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
// extractSessionIdFromToken extracts session ID from JWT session token
func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
// Parse JWT and extract session ID from claims
claims, err := s.TokenGenerator.ValidateJWTWithClaims(sessionToken)
// Validate JWT and extract session claims
claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
if err != nil {
// For test compatibility, also handle direct session IDs
if len(sessionToken) == 32 { // Typical session ID length
@ -866,8 +872,8 @@ func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken s
return fmt.Errorf("session token cannot be empty")
}
// Validate JWT token format
_, err := s.TokenGenerator.ValidateJWTWithClaims(sessionToken)
// Just validate the signature
_, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
if err != nil {
return fmt.Errorf("invalid session token format: %w", err)
}

2
weed/s3api/s3api_sts.go

@ -483,7 +483,7 @@ func (h *STSHandlers) prepareSTSCredentials(roleArn, roleSessionName, principalA
}
// Generate JWT session token
sessionToken, err := h.stsService.TokenGenerator.GenerateJWTWithClaims(claims)
sessionToken, err := h.stsService.GetTokenGenerator().GenerateJWTWithClaims(claims)
if err != nil {
return STSCredentials{}, nil, fmt.Errorf("failed to generate session token: %w", err)
}

Loading…
Cancel
Save