You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							412 lines
						
					
					
						
							10 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							412 lines
						
					
					
						
							10 KiB
						
					
					
				| package sts | |
| 
 | |
| import ( | |
| 	"context" | |
| 	"fmt" | |
| 	"strings" | |
| 	"testing" | |
| 	"time" | |
| 
 | |
| 	"github.com/seaweedfs/seaweedfs/weed/iam/providers" | |
| 	"github.com/stretchr/testify/assert" | |
| 	"github.com/stretchr/testify/require" | |
| ) | |
| 
 | |
| // TestSTSServiceInitialization tests STS service initialization | |
| func TestSTSServiceInitialization(t *testing.T) { | |
| 	tests := []struct { | |
| 		name    string | |
| 		config  *STSConfig | |
| 		wantErr bool | |
| 	}{ | |
| 		{ | |
| 			name: "valid config", | |
| 			config: &STSConfig{ | |
| 				TokenDuration:    time.Hour, | |
| 				MaxSessionLength: time.Hour * 12, | |
| 				Issuer:           "seaweedfs-sts", | |
| 				SigningKey:       []byte("test-signing-key"), | |
| 			}, | |
| 			wantErr: false, | |
| 		}, | |
| 		{ | |
| 			name: "missing signing key", | |
| 			config: &STSConfig{ | |
| 				TokenDuration: time.Hour, | |
| 				Issuer:        "seaweedfs-sts", | |
| 			}, | |
| 			wantErr: true, | |
| 		}, | |
| 		{ | |
| 			name: "invalid token duration", | |
| 			config: &STSConfig{ | |
| 				TokenDuration: -time.Hour, | |
| 				Issuer:        "seaweedfs-sts", | |
| 				SigningKey:    []byte("test-key"), | |
| 			}, | |
| 			wantErr: true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			service := NewSTSService() | |
| 
 | |
| 			err := service.Initialize(tt.config) | |
| 
 | |
| 			if tt.wantErr { | |
| 				assert.Error(t, err) | |
| 			} else { | |
| 				assert.NoError(t, err) | |
| 				assert.True(t, service.IsInitialized()) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens | |
| func TestAssumeRoleWithWebIdentity(t *testing.T) { | |
| 	service := setupTestSTSService(t) | |
| 
 | |
| 	tests := []struct { | |
| 		name             string | |
| 		roleArn          string | |
| 		webIdentityToken string | |
| 		sessionName      string | |
| 		durationSeconds  *int64 | |
| 		wantErr          bool | |
| 		expectedSubject  string | |
| 	}{ | |
| 		{ | |
| 			name:             "successful role assumption", | |
| 			roleArn:          "arn:seaweed:iam::role/TestRole", | |
| 			webIdentityToken: "valid-oidc-token", | |
| 			sessionName:      "test-session", | |
| 			durationSeconds:  nil, // Use default | |
| 			wantErr:          false, | |
| 			expectedSubject:  "test-user-id", | |
| 		}, | |
| 		{ | |
| 			name:             "invalid web identity token", | |
| 			roleArn:          "arn:seaweed:iam::role/TestRole", | |
| 			webIdentityToken: "invalid-token", | |
| 			sessionName:      "test-session", | |
| 			wantErr:          true, | |
| 		}, | |
| 		{ | |
| 			name:             "non-existent role", | |
| 			roleArn:          "arn:seaweed:iam::role/NonExistentRole", | |
| 			webIdentityToken: "valid-oidc-token", | |
| 			sessionName:      "test-session", | |
| 			wantErr:          true, | |
| 		}, | |
| 		{ | |
| 			name:             "custom session duration", | |
| 			roleArn:          "arn:seaweed:iam::role/TestRole", | |
| 			webIdentityToken: "valid-oidc-token", | |
| 			sessionName:      "test-session", | |
| 			durationSeconds:  int64Ptr(7200), // 2 hours | |
| 			wantErr:          false, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			ctx := context.Background() | |
| 
 | |
| 			request := &AssumeRoleWithWebIdentityRequest{ | |
| 				RoleArn:          tt.roleArn, | |
| 				WebIdentityToken: tt.webIdentityToken, | |
| 				RoleSessionName:  tt.sessionName, | |
| 				DurationSeconds:  tt.durationSeconds, | |
| 			} | |
| 
 | |
| 			response, err := service.AssumeRoleWithWebIdentity(ctx, request) | |
| 
 | |
| 			if tt.wantErr { | |
| 				assert.Error(t, err) | |
| 				assert.Nil(t, response) | |
| 			} else { | |
| 				assert.NoError(t, err) | |
| 				assert.NotNil(t, response) | |
| 				assert.NotNil(t, response.Credentials) | |
| 				assert.NotNil(t, response.AssumedRoleUser) | |
| 
 | |
| 				// Verify credentials | |
| 				creds := response.Credentials | |
| 				assert.NotEmpty(t, creds.AccessKeyId) | |
| 				assert.NotEmpty(t, creds.SecretAccessKey) | |
| 				assert.NotEmpty(t, creds.SessionToken) | |
| 				assert.True(t, creds.Expiration.After(time.Now())) | |
| 
 | |
| 				// Verify assumed role user | |
| 				user := response.AssumedRoleUser | |
| 				assert.Equal(t, tt.roleArn, user.AssumedRoleId) | |
| 				assert.Contains(t, user.Arn, tt.sessionName) | |
| 
 | |
| 				if tt.expectedSubject != "" { | |
| 					assert.Equal(t, tt.expectedSubject, user.Subject) | |
| 				} | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestAssumeRoleWithLDAP tests role assumption with LDAP credentials | |
| func TestAssumeRoleWithLDAP(t *testing.T) { | |
| 	service := setupTestSTSService(t) | |
| 
 | |
| 	tests := []struct { | |
| 		name        string | |
| 		roleArn     string | |
| 		username    string | |
| 		password    string | |
| 		sessionName string | |
| 		wantErr     bool | |
| 	}{ | |
| 		{ | |
| 			name:        "successful LDAP role assumption", | |
| 			roleArn:     "arn:seaweed:iam::role/LDAPRole", | |
| 			username:    "testuser", | |
| 			password:    "testpass", | |
| 			sessionName: "ldap-session", | |
| 			wantErr:     false, | |
| 		}, | |
| 		{ | |
| 			name:        "invalid LDAP credentials", | |
| 			roleArn:     "arn:seaweed:iam::role/LDAPRole", | |
| 			username:    "testuser", | |
| 			password:    "wrongpass", | |
| 			sessionName: "ldap-session", | |
| 			wantErr:     true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			ctx := context.Background() | |
| 
 | |
| 			request := &AssumeRoleWithCredentialsRequest{ | |
| 				RoleArn:         tt.roleArn, | |
| 				Username:        tt.username, | |
| 				Password:        tt.password, | |
| 				RoleSessionName: tt.sessionName, | |
| 				ProviderName:    "test-ldap", | |
| 			} | |
| 
 | |
| 			response, err := service.AssumeRoleWithCredentials(ctx, request) | |
| 
 | |
| 			if tt.wantErr { | |
| 				assert.Error(t, err) | |
| 				assert.Nil(t, response) | |
| 			} else { | |
| 				assert.NoError(t, err) | |
| 				assert.NotNil(t, response) | |
| 				assert.NotNil(t, response.Credentials) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestSessionTokenValidation tests session token validation | |
| func TestSessionTokenValidation(t *testing.T) { | |
| 	service := setupTestSTSService(t) | |
| 	ctx := context.Background() | |
| 
 | |
| 	// First, create a session | |
| 	request := &AssumeRoleWithWebIdentityRequest{ | |
| 		RoleArn:          "arn:seaweed:iam::role/TestRole", | |
| 		WebIdentityToken: "valid-oidc-token", | |
| 		RoleSessionName:  "test-session", | |
| 	} | |
| 
 | |
| 	response, err := service.AssumeRoleWithWebIdentity(ctx, request) | |
| 	require.NoError(t, err) | |
| 	require.NotNil(t, response) | |
| 
 | |
| 	sessionToken := response.Credentials.SessionToken | |
| 
 | |
| 	tests := []struct { | |
| 		name    string | |
| 		token   string | |
| 		wantErr bool | |
| 	}{ | |
| 		{ | |
| 			name:    "valid session token", | |
| 			token:   sessionToken, | |
| 			wantErr: false, | |
| 		}, | |
| 		{ | |
| 			name:    "invalid session token", | |
| 			token:   "invalid-session-token", | |
| 			wantErr: true, | |
| 		}, | |
| 		{ | |
| 			name:    "empty session token", | |
| 			token:   "", | |
| 			wantErr: true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			session, err := service.ValidateSessionToken(ctx, tt.token) | |
| 
 | |
| 			if tt.wantErr { | |
| 				assert.Error(t, err) | |
| 				assert.Nil(t, session) | |
| 			} else { | |
| 				assert.NoError(t, err) | |
| 				assert.NotNil(t, session) | |
| 				assert.Equal(t, "test-session", session.SessionName) | |
| 				assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime | |
| // Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration | |
| func TestSessionTokenPersistence(t *testing.T) { | |
| 	service := setupTestSTSService(t) | |
| 	ctx := context.Background() | |
| 
 | |
| 	// Create a session first | |
| 	request := &AssumeRoleWithWebIdentityRequest{ | |
| 		RoleArn:          "arn:seaweed:iam::role/TestRole", | |
| 		WebIdentityToken: "valid-oidc-token", | |
| 		RoleSessionName:  "test-session", | |
| 	} | |
| 
 | |
| 	response, err := service.AssumeRoleWithWebIdentity(ctx, request) | |
| 	require.NoError(t, err) | |
| 
 | |
| 	sessionToken := response.Credentials.SessionToken | |
| 
 | |
| 	// Verify token is valid initially | |
| 	session, err := service.ValidateSessionToken(ctx, sessionToken) | |
| 	assert.NoError(t, err) | |
| 	assert.NotNil(t, session) | |
| 	assert.Equal(t, "test-session", session.SessionName) | |
| 
 | |
| 	// In a stateless JWT system, tokens remain valid throughout their lifetime | |
| 	// Multiple validations should all succeed as long as the token hasn't expired | |
| 	session2, err := service.ValidateSessionToken(ctx, sessionToken) | |
| 	assert.NoError(t, err, "Token should remain valid in stateless system") | |
| 	assert.NotNil(t, session2, "Session should be returned from JWT token") | |
| 	assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent") | |
| } | |
| 
 | |
| // Helper functions | |
|  | |
| func setupTestSTSService(t *testing.T) *STSService { | |
| 	service := NewSTSService() | |
| 
 | |
| 	config := &STSConfig{ | |
| 		TokenDuration:    time.Hour, | |
| 		MaxSessionLength: time.Hour * 12, | |
| 		Issuer:           "test-sts", | |
| 		SigningKey:       []byte("test-signing-key-32-characters-long"), | |
| 	} | |
| 
 | |
| 	err := service.Initialize(config) | |
| 	require.NoError(t, err) | |
| 
 | |
| 	// Set up mock trust policy validator (required for STS testing) | |
| 	mockValidator := &MockTrustPolicyValidator{} | |
| 	service.SetTrustPolicyValidator(mockValidator) | |
| 
 | |
| 	// Register test providers | |
| 	mockOIDCProvider := &MockIdentityProvider{ | |
| 		name: "test-oidc", | |
| 		validTokens: map[string]*providers.TokenClaims{ | |
| 			"valid-oidc-token": { | |
| 				Subject: "test-user-id", | |
| 				Issuer:  "test-issuer", | |
| 				Claims: map[string]interface{}{ | |
| 					"email": "test@example.com", | |
| 					"name":  "Test User", | |
| 				}, | |
| 			}, | |
| 		}, | |
| 	} | |
| 
 | |
| 	mockLDAPProvider := &MockIdentityProvider{ | |
| 		name: "test-ldap", | |
| 		validCredentials: map[string]string{ | |
| 			"testuser": "testpass", | |
| 		}, | |
| 	} | |
| 
 | |
| 	service.RegisterProvider(mockOIDCProvider) | |
| 	service.RegisterProvider(mockLDAPProvider) | |
| 
 | |
| 	return service | |
| } | |
| 
 | |
// int64Ptr allocates a new int64 holding v and returns its address. It is
// convenient for filling optional *int64 request fields such as DurationSeconds.
// Each call yields a distinct pointer.
func int64Ptr(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
| 
 | |
| // Mock identity provider for testing | |
| type MockIdentityProvider struct { | |
| 	name             string | |
| 	validTokens      map[string]*providers.TokenClaims | |
| 	validCredentials map[string]string | |
| } | |
| 
 | |
| func (m *MockIdentityProvider) Name() string { | |
| 	return m.name | |
| } | |
| 
 | |
| func (m *MockIdentityProvider) Initialize(config interface{}) error { | |
| 	return nil | |
| } | |
| 
 | |
| func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { | |
| 	// Handle OIDC tokens | |
| 	if claims, exists := m.validTokens[token]; exists { | |
| 		email, _ := claims.GetClaimString("email") | |
| 		name, _ := claims.GetClaimString("name") | |
| 
 | |
| 		return &providers.ExternalIdentity{ | |
| 			UserID:      claims.Subject, | |
| 			Email:       email, | |
| 			DisplayName: name, | |
| 			Provider:    m.name, | |
| 		}, nil | |
| 	} | |
| 
 | |
| 	// Handle LDAP credentials (username:password format) | |
| 	if m.validCredentials != nil { | |
| 		parts := strings.Split(token, ":") | |
| 		if len(parts) == 2 { | |
| 			username, password := parts[0], parts[1] | |
| 			if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password { | |
| 				return &providers.ExternalIdentity{ | |
| 					UserID:      username, | |
| 					Email:       username + "@" + m.name + ".com", | |
| 					DisplayName: "Test User " + username, | |
| 					Provider:    m.name, | |
| 				}, nil | |
| 			} | |
| 		} | |
| 	} | |
| 
 | |
| 	return nil, fmt.Errorf("invalid token") | |
| } | |
| 
 | |
| func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { | |
| 	return &providers.ExternalIdentity{ | |
| 		UserID:   userID, | |
| 		Email:    userID + "@" + m.name + ".com", | |
| 		Provider: m.name, | |
| 	}, nil | |
| } | |
| 
 | |
| func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { | |
| 	if claims, exists := m.validTokens[token]; exists { | |
| 		return claims, nil | |
| 	} | |
| 	return nil, fmt.Errorf("invalid token") | |
| }
 |