package sts

import (
	"context"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/seaweedfs/seaweedfs/weed/iam/providers"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestSTSServiceInitialization checks that STSService.Initialize accepts a
// complete configuration and rejects configurations that lack a signing key
// or carry a negative token duration.
func TestSTSServiceInitialization(t *testing.T) {
	tests := []struct {
		name    string
		config  *STSConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: &STSConfig{
				TokenDuration:    time.Hour,
				MaxSessionLength: time.Hour * 12,
				Issuer:           "seaweedfs-sts",
				SigningKey:       []byte("test-signing-key"),
			},
			wantErr: false,
		},
		{
			name: "missing signing key",
			config: &STSConfig{
				TokenDuration: time.Hour,
				Issuer:        "seaweedfs-sts",
			},
			wantErr: true,
		},
		{
			name: "invalid token duration",
			config: &STSConfig{
				TokenDuration: -time.Hour,
				Issuer:        "seaweedfs-sts",
				SigningKey:    []byte("test-key"),
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			service := NewSTSService()

			err := service.Initialize(tt.config)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			assert.True(t, service.IsInitialized())
		})
	}
}

// TestAssumeRoleWithWebIdentity exercises role assumption with OIDC tokens
// against the mock provider registered by setupTestSTSService: a valid token,
// an unknown token, a non-existent role and a caller-supplied duration.
func TestAssumeRoleWithWebIdentity(t *testing.T) {
	service := setupTestSTSService(t)

	tests := []struct {
		name             string
		roleArn          string
		webIdentityToken string
		sessionName      string
		durationSeconds  *int64
		wantErr          bool
		expectedSubject  string
	}{
		{
			name:             "successful role assumption",
			roleArn:          "arn:seaweed:iam::role/TestRole",
			webIdentityToken: "valid-oidc-token",
			sessionName:      "test-session",
			durationSeconds:  nil, // Use default
			wantErr:          false,
			expectedSubject:  "test-user-id",
		},
		{
			name:             "invalid web identity token",
			roleArn:          "arn:seaweed:iam::role/TestRole",
			webIdentityToken: "invalid-token",
			sessionName:      "test-session",
			wantErr:          true,
		},
		{
			name:             "non-existent role",
			roleArn:          "arn:seaweed:iam::role/NonExistentRole",
			webIdentityToken: "valid-oidc-token",
			sessionName:      "test-session",
			wantErr:          true,
		},
		{
			name:             "custom session duration",
			roleArn:          "arn:seaweed:iam::role/TestRole",
			webIdentityToken: "valid-oidc-token",
			sessionName:      "test-session",
			durationSeconds:  int64Ptr(7200), // 2 hours
			wantErr:          false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			request := &AssumeRoleWithWebIdentityRequest{
				RoleArn:          tt.roleArn,
				WebIdentityToken: tt.webIdentityToken,
				RoleSessionName:  tt.sessionName,
				DurationSeconds:  tt.durationSeconds,
			}

			response, err := service.AssumeRoleWithWebIdentity(ctx, request)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, response)
				return
			}

			// The checks below dereference the response, so abort this
			// subtest immediately instead of panicking on a nil pointer.
			require.NoError(t, err)
			require.NotNil(t, response)
			require.NotNil(t, response.Credentials)
			require.NotNil(t, response.AssumedRoleUser)

			// Verify credentials
			creds := response.Credentials
			assert.NotEmpty(t, creds.AccessKeyId)
			assert.NotEmpty(t, creds.SecretAccessKey)
			assert.NotEmpty(t, creds.SessionToken)
			assert.True(t, creds.Expiration.After(time.Now()))

			// Verify assumed role user
			user := response.AssumedRoleUser
			assert.Equal(t, tt.roleArn, user.AssumedRoleId)
			assert.Contains(t, user.Arn, tt.sessionName)
			if tt.expectedSubject != "" {
				assert.Equal(t, tt.expectedSubject, user.Subject)
			}
		})
	}
}

// TestAssumeRoleWithLDAP exercises role assumption with username/password
// credentials routed to the "test-ldap" mock provider.
func TestAssumeRoleWithLDAP(t *testing.T) {
	service := setupTestSTSService(t)

	tests := []struct {
		name        string
		roleArn     string
		username    string
		password    string
		sessionName string
		wantErr     bool
	}{
		{
			name:        "successful LDAP role assumption",
			roleArn:     "arn:seaweed:iam::role/LDAPRole",
			username:    "testuser",
			password:    "testpass",
			sessionName: "ldap-session",
			wantErr:     false,
		},
		{
			name:        "invalid LDAP credentials",
			roleArn:     "arn:seaweed:iam::role/LDAPRole",
			username:    "testuser",
			password:    "wrongpass",
			sessionName: "ldap-session",
			wantErr:     true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			request := &AssumeRoleWithCredentialsRequest{
				RoleArn:         tt.roleArn,
				Username:        tt.username,
				Password:        tt.password,
				RoleSessionName: tt.sessionName,
				ProviderName:    "test-ldap",
			}

			response, err := service.AssumeRoleWithCredentials(ctx, request)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, response)
				return
			}

			// response is dereferenced next; stop here if it is missing.
			require.NoError(t, err)
			require.NotNil(t, response)
			assert.NotNil(t, response.Credentials)
		})
	}
}

// TestSessionTokenValidation creates a session through
// AssumeRoleWithWebIdentity and then checks ValidateSessionToken against the
// issued token, a garbage token and an empty token.
func TestSessionTokenValidation(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()

	// First, create a session
	request := &AssumeRoleWithWebIdentityRequest{
		RoleArn:          "arn:seaweed:iam::role/TestRole",
		WebIdentityToken: "valid-oidc-token",
		RoleSessionName:  "test-session",
	}

	response, err := service.AssumeRoleWithWebIdentity(ctx, request)
	require.NoError(t, err)
	require.NotNil(t, response)
	require.NotNil(t, response.Credentials)

	sessionToken := response.Credentials.SessionToken

	tests := []struct {
		name    string
		token   string
		wantErr bool
	}{
		{
			name:    "valid session token",
			token:   sessionToken,
			wantErr: false,
		},
		{
			name:    "invalid session token",
			token:   "invalid-session-token",
			wantErr: true,
		},
		{
			name:    "empty session token",
			token:   "",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			session, err := service.ValidateSessionToken(ctx, tt.token)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, session)
				return
			}

			// session fields are read below; stop here if it is missing.
			require.NoError(t, err)
			require.NotNil(t, session)
			assert.Equal(t, "test-session", session.SessionName)
			assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
		})
	}
}

// TestSessionRevocation checks that a session token validates before
// RevokeSession is called on it and fails validation afterwards.
func TestSessionRevocation(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()

	// Create a session first
	request := &AssumeRoleWithWebIdentityRequest{
		RoleArn:          "arn:seaweed:iam::role/TestRole",
		WebIdentityToken: "valid-oidc-token",
		RoleSessionName:  "test-session",
	}

	response, err := service.AssumeRoleWithWebIdentity(ctx, request)
	require.NoError(t, err)
	require.NotNil(t, response)
	require.NotNil(t, response.Credentials)

	sessionToken := response.Credentials.SessionToken

	// Verify token is valid before revocation
	session, err := service.ValidateSessionToken(ctx, sessionToken)
	assert.NoError(t, err)
	assert.NotNil(t, session)

	// Revoke the session
	err = service.RevokeSession(ctx, sessionToken)
	assert.NoError(t, err)

	// Verify token is no longer valid after revocation
	session, err = service.ValidateSessionToken(ctx, sessionToken)
	assert.Error(t, err)
	assert.Nil(t, session)
}

// Helper functions

// setupTestSTSService returns an initialized STSService with two mock
// providers registered: "test-oidc", which accepts the token
// "valid-oidc-token", and "test-ldap", which accepts testuser/testpass.
// The calling test is failed immediately if initialization returns an error.
func setupTestSTSService(t *testing.T) *STSService {
	// Attribute setup failures to the calling test's line.
	t.Helper()

	service := NewSTSService()

	config := &STSConfig{
		TokenDuration:    time.Hour,
		MaxSessionLength: time.Hour * 12,
		Issuer:           "test-sts",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
	}

	err := service.Initialize(config)
	require.NoError(t, err)

	// Register test providers
	mockOIDCProvider := &MockIdentityProvider{
		name: "test-oidc",
		validTokens: map[string]*providers.TokenClaims{
			"valid-oidc-token": {
				Subject: "test-user-id",
				Issuer:  "test-issuer",
				Claims: map[string]interface{}{
					"email": "test@example.com",
					"name":  "Test User",
				},
			},
		},
	}

	mockLDAPProvider := &MockIdentityProvider{
		name: "test-ldap",
		validCredentials: map[string]string{
			"testuser": "testpass",
		},
	}

	service.RegisterProvider(mockOIDCProvider)
	service.RegisterProvider(mockLDAPProvider)

	return service
}

// int64Ptr returns a pointer to a copy of v, for populating optional
// *int64 request fields from a literal.
func int64Ptr(v int64) *int64 {
	return &v
}

// MockIdentityProvider is an in-memory identity provider for tests.
// validTokens maps accepted bearer tokens to the claims they resolve to;
// validCredentials maps accepted usernames to their passwords. Either map
// may be left nil.
type MockIdentityProvider struct {
	name             string
	validTokens      map[string]*providers.TokenClaims
	validCredentials map[string]string
}

// Name returns the provider name given at construction.
func (m *MockIdentityProvider) Name() string {
	return m.name
}

// Initialize ignores config and always succeeds.
func (m *MockIdentityProvider) Initialize(config interface{}) error {
	return nil
}

// Authenticate resolves token to an identity. The token is first looked up
// verbatim in validTokens; otherwise it is interpreted as "username:password"
// and checked against validCredentials. An "invalid token" error is returned
// when neither lookup matches.
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
	// Handle OIDC tokens
	if claims, exists := m.validTokens[token]; exists {
		// The second return value of GetClaimString is deliberately
		// discarded: the identity is built from whatever string it returns.
		email, _ := claims.GetClaimString("email")
		name, _ := claims.GetClaimString("name")

		return &providers.ExternalIdentity{
			UserID:      claims.Subject,
			Email:       email,
			DisplayName: name,
			Provider:    m.name,
		}, nil
	}

	// Handle LDAP credentials (username:password format). Split on the first
	// colon only so that passwords containing ':' are still accepted. Lookups
	// on a nil validCredentials map are safe and simply miss.
	parts := strings.SplitN(token, ":", 2)
	if len(parts) == 2 {
		username, password := parts[0], parts[1]
		if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
			return &providers.ExternalIdentity{
				UserID:      username,
				Email:       username + "@" + m.name + ".com",
				DisplayName: "Test User " + username,
				Provider:    m.name,
			}, nil
		}
	}

	return nil, fmt.Errorf("invalid token")
}

// GetUserInfo returns a synthetic identity for userID with an e-mail address
// derived from the user ID and provider name. It never fails.
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
	return &providers.ExternalIdentity{
		UserID:   userID,
		Email:    userID + "@" + m.name + ".com",
		Provider: m.name,
	}, nil
}

// ValidateToken returns the claims registered for token in validTokens, or an
// "invalid token" error when the token is unknown.
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
	if claims, exists := m.validTokens[token]; exists {
		return claims, nil
	}
	return nil, fmt.Errorf("invalid token")
}