You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							453 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							453 lines
						
					
					
						
							12 KiB
						
					
					
				
								package sts
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"context"
							 | 
						|
									"fmt"
							 | 
						|
									"strings"
							 | 
						|
									"testing"
							 | 
						|
									"time"
							 | 
						|
								
							 | 
						|
									"github.com/golang-jwt/jwt/v5"
							 | 
						|
									"github.com/seaweedfs/seaweedfs/weed/iam/providers"
							 | 
						|
									"github.com/stretchr/testify/assert"
							 | 
						|
									"github.com/stretchr/testify/require"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// createSTSTestJWT creates a test JWT token for STS service tests
							 | 
						|
								func createSTSTestJWT(t *testing.T, issuer, subject string) string {
							 | 
						|
									token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
							 | 
						|
										"iss": issuer,
							 | 
						|
										"sub": subject,
							 | 
						|
										"aud": "test-client",
							 | 
						|
										"exp": time.Now().Add(time.Hour).Unix(),
							 | 
						|
										"iat": time.Now().Unix(),
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									tokenString, err := token.SignedString([]byte("test-signing-key"))
							 | 
						|
									require.NoError(t, err)
							 | 
						|
									return tokenString
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestSTSServiceInitialization tests STS service initialization
							 | 
						|
								func TestSTSServiceInitialization(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name    string
							 | 
						|
										config  *STSConfig
							 | 
						|
										wantErr bool
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name: "valid config",
							 | 
						|
											config: &STSConfig{
							 | 
						|
												TokenDuration:    FlexibleDuration{time.Hour},
							 | 
						|
												MaxSessionLength: FlexibleDuration{time.Hour * 12},
							 | 
						|
												Issuer:           "seaweedfs-sts",
							 | 
						|
												SigningKey:       []byte("test-signing-key"),
							 | 
						|
											},
							 | 
						|
											wantErr: false,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name: "missing signing key",
							 | 
						|
											config: &STSConfig{
							 | 
						|
												TokenDuration: FlexibleDuration{time.Hour},
							 | 
						|
												Issuer:        "seaweedfs-sts",
							 | 
						|
											},
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name: "invalid token duration",
							 | 
						|
											config: &STSConfig{
							 | 
						|
												TokenDuration: FlexibleDuration{-time.Hour},
							 | 
						|
												Issuer:        "seaweedfs-sts",
							 | 
						|
												SigningKey:    []byte("test-key"),
							 | 
						|
											},
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											service := NewSTSService()
							 | 
						|
								
							 | 
						|
											err := service.Initialize(tt.config)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												assert.Error(t, err)
							 | 
						|
											} else {
							 | 
						|
												assert.NoError(t, err)
							 | 
						|
												assert.True(t, service.IsInitialized())
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
							 | 
						|
								func TestAssumeRoleWithWebIdentity(t *testing.T) {
							 | 
						|
									service := setupTestSTSService(t)
							 | 
						|
								
							 | 
						|
									tests := []struct {
							 | 
						|
										name             string
							 | 
						|
										roleArn          string
							 | 
						|
										webIdentityToken string
							 | 
						|
										sessionName      string
							 | 
						|
										durationSeconds  *int64
							 | 
						|
										wantErr          bool
							 | 
						|
										expectedSubject  string
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:             "successful role assumption",
							 | 
						|
											roleArn:          "arn:seaweed:iam::role/TestRole",
							 | 
						|
											webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
							 | 
						|
											sessionName:      "test-session",
							 | 
						|
											durationSeconds:  nil, // Use default
							 | 
						|
											wantErr:          false,
							 | 
						|
											expectedSubject:  "test-user-id",
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:             "invalid web identity token",
							 | 
						|
											roleArn:          "arn:seaweed:iam::role/TestRole",
							 | 
						|
											webIdentityToken: "invalid-token",
							 | 
						|
											sessionName:      "test-session",
							 | 
						|
											wantErr:          true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:             "non-existent role",
							 | 
						|
											roleArn:          "arn:seaweed:iam::role/NonExistentRole",
							 | 
						|
											webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
							 | 
						|
											sessionName:      "test-session",
							 | 
						|
											wantErr:          true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:             "custom session duration",
							 | 
						|
											roleArn:          "arn:seaweed:iam::role/TestRole",
							 | 
						|
											webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
							 | 
						|
											sessionName:      "test-session",
							 | 
						|
											durationSeconds:  int64Ptr(7200), // 2 hours
							 | 
						|
											wantErr:          false,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											ctx := context.Background()
							 | 
						|
								
							 | 
						|
											request := &AssumeRoleWithWebIdentityRequest{
							 | 
						|
												RoleArn:          tt.roleArn,
							 | 
						|
												WebIdentityToken: tt.webIdentityToken,
							 | 
						|
												RoleSessionName:  tt.sessionName,
							 | 
						|
												DurationSeconds:  tt.durationSeconds,
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											response, err := service.AssumeRoleWithWebIdentity(ctx, request)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												assert.Error(t, err)
							 | 
						|
												assert.Nil(t, response)
							 | 
						|
											} else {
							 | 
						|
												assert.NoError(t, err)
							 | 
						|
												assert.NotNil(t, response)
							 | 
						|
												assert.NotNil(t, response.Credentials)
							 | 
						|
												assert.NotNil(t, response.AssumedRoleUser)
							 | 
						|
								
							 | 
						|
												// Verify credentials
							 | 
						|
												creds := response.Credentials
							 | 
						|
												assert.NotEmpty(t, creds.AccessKeyId)
							 | 
						|
												assert.NotEmpty(t, creds.SecretAccessKey)
							 | 
						|
												assert.NotEmpty(t, creds.SessionToken)
							 | 
						|
												assert.True(t, creds.Expiration.After(time.Now()))
							 | 
						|
								
							 | 
						|
												// Verify assumed role user
							 | 
						|
												user := response.AssumedRoleUser
							 | 
						|
												assert.Equal(t, tt.roleArn, user.AssumedRoleId)
							 | 
						|
												assert.Contains(t, user.Arn, tt.sessionName)
							 | 
						|
								
							 | 
						|
												if tt.expectedSubject != "" {
							 | 
						|
													assert.Equal(t, tt.expectedSubject, user.Subject)
							 | 
						|
												}
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
							 | 
						|
								func TestAssumeRoleWithLDAP(t *testing.T) {
							 | 
						|
									service := setupTestSTSService(t)
							 | 
						|
								
							 | 
						|
									tests := []struct {
							 | 
						|
										name        string
							 | 
						|
										roleArn     string
							 | 
						|
										username    string
							 | 
						|
										password    string
							 | 
						|
										sessionName string
							 | 
						|
										wantErr     bool
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:        "successful LDAP role assumption",
							 | 
						|
											roleArn:     "arn:seaweed:iam::role/LDAPRole",
							 | 
						|
											username:    "testuser",
							 | 
						|
											password:    "testpass",
							 | 
						|
											sessionName: "ldap-session",
							 | 
						|
											wantErr:     false,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:        "invalid LDAP credentials",
							 | 
						|
											roleArn:     "arn:seaweed:iam::role/LDAPRole",
							 | 
						|
											username:    "testuser",
							 | 
						|
											password:    "wrongpass",
							 | 
						|
											sessionName: "ldap-session",
							 | 
						|
											wantErr:     true,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											ctx := context.Background()
							 | 
						|
								
							 | 
						|
											request := &AssumeRoleWithCredentialsRequest{
							 | 
						|
												RoleArn:         tt.roleArn,
							 | 
						|
												Username:        tt.username,
							 | 
						|
												Password:        tt.password,
							 | 
						|
												RoleSessionName: tt.sessionName,
							 | 
						|
												ProviderName:    "test-ldap",
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											response, err := service.AssumeRoleWithCredentials(ctx, request)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												assert.Error(t, err)
							 | 
						|
												assert.Nil(t, response)
							 | 
						|
											} else {
							 | 
						|
												assert.NoError(t, err)
							 | 
						|
												assert.NotNil(t, response)
							 | 
						|
												assert.NotNil(t, response.Credentials)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestSessionTokenValidation tests session token validation
							 | 
						|
								func TestSessionTokenValidation(t *testing.T) {
							 | 
						|
									service := setupTestSTSService(t)
							 | 
						|
									ctx := context.Background()
							 | 
						|
								
							 | 
						|
									// First, create a session
							 | 
						|
									request := &AssumeRoleWithWebIdentityRequest{
							 | 
						|
										RoleArn:          "arn:seaweed:iam::role/TestRole",
							 | 
						|
										WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
							 | 
						|
										RoleSessionName:  "test-session",
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									response, err := service.AssumeRoleWithWebIdentity(ctx, request)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
									require.NotNil(t, response)
							 | 
						|
								
							 | 
						|
									sessionToken := response.Credentials.SessionToken
							 | 
						|
								
							 | 
						|
									tests := []struct {
							 | 
						|
										name    string
							 | 
						|
										token   string
							 | 
						|
										wantErr bool
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:    "valid session token",
							 | 
						|
											token:   sessionToken,
							 | 
						|
											wantErr: false,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "invalid session token",
							 | 
						|
											token:   "invalid-session-token",
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:    "empty session token",
							 | 
						|
											token:   "",
							 | 
						|
											wantErr: true,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											session, err := service.ValidateSessionToken(ctx, tt.token)
							 | 
						|
								
							 | 
						|
											if tt.wantErr {
							 | 
						|
												assert.Error(t, err)
							 | 
						|
												assert.Nil(t, session)
							 | 
						|
											} else {
							 | 
						|
												assert.NoError(t, err)
							 | 
						|
												assert.NotNil(t, session)
							 | 
						|
												assert.Equal(t, "test-session", session.SessionName)
							 | 
						|
												assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
							 | 
						|
								// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
							 | 
						|
								func TestSessionTokenPersistence(t *testing.T) {
							 | 
						|
									service := setupTestSTSService(t)
							 | 
						|
									ctx := context.Background()
							 | 
						|
								
							 | 
						|
									// Create a session first
							 | 
						|
									request := &AssumeRoleWithWebIdentityRequest{
							 | 
						|
										RoleArn:          "arn:seaweed:iam::role/TestRole",
							 | 
						|
										WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
							 | 
						|
										RoleSessionName:  "test-session",
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									response, err := service.AssumeRoleWithWebIdentity(ctx, request)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
								
							 | 
						|
									sessionToken := response.Credentials.SessionToken
							 | 
						|
								
							 | 
						|
									// Verify token is valid initially
							 | 
						|
									session, err := service.ValidateSessionToken(ctx, sessionToken)
							 | 
						|
									assert.NoError(t, err)
							 | 
						|
									assert.NotNil(t, session)
							 | 
						|
									assert.Equal(t, "test-session", session.SessionName)
							 | 
						|
								
							 | 
						|
									// In a stateless JWT system, tokens remain valid throughout their lifetime
							 | 
						|
									// Multiple validations should all succeed as long as the token hasn't expired
							 | 
						|
									session2, err := service.ValidateSessionToken(ctx, sessionToken)
							 | 
						|
									assert.NoError(t, err, "Token should remain valid in stateless system")
							 | 
						|
									assert.NotNil(t, session2, "Session should be returned from JWT token")
							 | 
						|
									assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Helper functions
							 | 
						|
								
							 | 
						|
								func setupTestSTSService(t *testing.T) *STSService {
							 | 
						|
									service := NewSTSService()
							 | 
						|
								
							 | 
						|
									config := &STSConfig{
							 | 
						|
										TokenDuration:    FlexibleDuration{time.Hour},
							 | 
						|
										MaxSessionLength: FlexibleDuration{time.Hour * 12},
							 | 
						|
										Issuer:           "test-sts",
							 | 
						|
										SigningKey:       []byte("test-signing-key-32-characters-long"),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									err := service.Initialize(config)
							 | 
						|
									require.NoError(t, err)
							 | 
						|
								
							 | 
						|
									// Set up mock trust policy validator (required for STS testing)
							 | 
						|
									mockValidator := &MockTrustPolicyValidator{}
							 | 
						|
									service.SetTrustPolicyValidator(mockValidator)
							 | 
						|
								
							 | 
						|
									// Register test providers
							 | 
						|
									mockOIDCProvider := &MockIdentityProvider{
							 | 
						|
										name: "test-oidc",
							 | 
						|
										validTokens: map[string]*providers.TokenClaims{
							 | 
						|
											createSTSTestJWT(t, "test-issuer", "test-user"): {
							 | 
						|
												Subject: "test-user-id",
							 | 
						|
												Issuer:  "test-issuer",
							 | 
						|
												Claims: map[string]interface{}{
							 | 
						|
													"email": "test@example.com",
							 | 
						|
													"name":  "Test User",
							 | 
						|
												},
							 | 
						|
											},
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									mockLDAPProvider := &MockIdentityProvider{
							 | 
						|
										name: "test-ldap",
							 | 
						|
										validCredentials: map[string]string{
							 | 
						|
											"testuser": "testpass",
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									service.RegisterProvider(mockOIDCProvider)
							 | 
						|
									service.RegisterProvider(mockLDAPProvider)
							 | 
						|
								
							 | 
						|
									return service
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func int64Ptr(v int64) *int64 {
							 | 
						|
									return &v
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Mock identity provider for testing
							 | 
						|
								type MockIdentityProvider struct {
							 | 
						|
									name             string
							 | 
						|
									validTokens      map[string]*providers.TokenClaims
							 | 
						|
									validCredentials map[string]string
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) Name() string {
							 | 
						|
									return m.name
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) GetIssuer() string {
							 | 
						|
									return "test-issuer" // This matches the issuer in the token claims
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) Initialize(config interface{}) error {
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
							 | 
						|
									// First try to parse as JWT token
							 | 
						|
									if len(token) > 20 && strings.Count(token, ".") >= 2 {
							 | 
						|
										parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
							 | 
						|
										if err == nil {
							 | 
						|
											if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
							 | 
						|
												issuer, _ := claims["iss"].(string)
							 | 
						|
												subject, _ := claims["sub"].(string)
							 | 
						|
								
							 | 
						|
												// Verify the issuer matches what we expect
							 | 
						|
												if issuer == "test-issuer" && subject != "" {
							 | 
						|
													return &providers.ExternalIdentity{
							 | 
						|
														UserID:      subject,
							 | 
						|
														Email:       subject + "@test-domain.com",
							 | 
						|
														DisplayName: "Test User " + subject,
							 | 
						|
														Provider:    m.name,
							 | 
						|
													}, nil
							 | 
						|
												}
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Handle legacy OIDC tokens (for backwards compatibility)
							 | 
						|
									if claims, exists := m.validTokens[token]; exists {
							 | 
						|
										email, _ := claims.GetClaimString("email")
							 | 
						|
										name, _ := claims.GetClaimString("name")
							 | 
						|
								
							 | 
						|
										return &providers.ExternalIdentity{
							 | 
						|
											UserID:      claims.Subject,
							 | 
						|
											Email:       email,
							 | 
						|
											DisplayName: name,
							 | 
						|
											Provider:    m.name,
							 | 
						|
										}, nil
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Handle LDAP credentials (username:password format)
							 | 
						|
									if m.validCredentials != nil {
							 | 
						|
										parts := strings.Split(token, ":")
							 | 
						|
										if len(parts) == 2 {
							 | 
						|
											username, password := parts[0], parts[1]
							 | 
						|
											if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
							 | 
						|
												return &providers.ExternalIdentity{
							 | 
						|
													UserID:      username,
							 | 
						|
													Email:       username + "@" + m.name + ".com",
							 | 
						|
													DisplayName: "Test User " + username,
							 | 
						|
													Provider:    m.name,
							 | 
						|
												}, nil
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return nil, fmt.Errorf("unknown test token: %s", token)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
							 | 
						|
									return &providers.ExternalIdentity{
							 | 
						|
										UserID:   userID,
							 | 
						|
										Email:    userID + "@" + m.name + ".com",
							 | 
						|
										Provider: m.name,
							 | 
						|
									}, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
							 | 
						|
									if claims, exists := m.validTokens[token]; exists {
							 | 
						|
										return claims, nil
							 | 
						|
									}
							 | 
						|
									return nil, fmt.Errorf("invalid token")
							 | 
						|
								}
							 |