Browse Source

feat(sts): pass filerAddress at call-time instead of init-time

This change addresses the requirement that filer addresses should be
passed when methods are called, not during initialization, to support:
- Dynamic filer failover and load balancing
- Runtime changes to filer topology
- Environment-agnostic configuration files

### Changes Made:

#### SessionStore Interface & Implementations:
- Updated SessionStore interface to accept filerAddress parameter in all methods
- Modified FilerSessionStore to remove filerAddress field from struct
- Updated MemorySessionStore to accept filerAddress (ignored) for interface consistency
- All methods now take: (ctx, filerAddress, sessionId, ...) parameters

#### STS Service Methods:
- Updated all public STS methods to accept filerAddress parameter:
  - AssumeRoleWithWebIdentity(ctx, filerAddress, request)
  - AssumeRoleWithCredentials(ctx, filerAddress, request)
  - ValidateSessionToken(ctx, filerAddress, sessionToken)
  - RevokeSession(ctx, filerAddress, sessionToken)
  - ExpireSessionForTesting(ctx, filerAddress, sessionToken)

#### Configuration Cleanup:
- Removed filerAddress from all configuration files (iam_config_distributed.json)
- Configuration now only contains basePath and other store-specific settings
- Makes configs environment-agnostic (dev/staging/prod compatible)

#### Test Updates:
- Updated all test files to pass testFilerAddress parameter
- Tests use dummy filerAddress ('localhost:8888') for consistency
- Maintains test functionality while validating new interface

### Benefits:
-  Filer addresses determined at runtime by caller (S3 API server)
-  Supports filer failover without service restart
-  Configuration files work across environments
-  Follows SeaweedFS patterns used elsewhere in codebase
-  Load balancer friendly - no filer affinity required
-  Horizontal scaling compatible

### Breaking Change:
This is a breaking change for any code calling STS service methods.
Callers must now pass filerAddress as the second parameter.
pull/7160/head
chrislu 2 months ago
parent
commit
8718c301ba
  1. 3
      test/s3/iam/iam_config_distributed.json
  2. 32
      weed/iam/sts/cross_instance_token_test.go
  3. 97
      weed/iam/sts/session_store.go
  4. 36
      weed/iam/sts/sts_service.go
  5. 20
      weed/iam/sts/sts_service_test.go

3
test/s3/iam/iam_config_distributed.json

@ -6,7 +6,6 @@
"signingKey": "dGVzdC1zaWduaW5nLWtleS0zMi1jaGFyYWN0ZXJzLWxvbmc=",
"sessionStoreType": "filer",
"sessionStoreConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/sessions"
},
"providers": [
@ -41,14 +40,12 @@
"defaultEffect": "Deny",
"storeType": "filer",
"storeConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/policies"
}
},
"roleStore": {
"storeType": "filer",
"storeConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/roles"
}
},

32
weed/iam/sts/cross_instance_token_test.go

@ -13,6 +13,7 @@ import (
// can be used and validated by other STS instances in a distributed environment
func TestCrossInstanceTokenUsage(t *testing.T) {
ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// Common configuration that would be shared across all instances in production
sharedConfig := &STSConfig{
@ -99,7 +100,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
}
// Instance A processes assume role request
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err, "Instance A should process assume role")
sessionToken := responseFromA.Credentials.SessionToken
@ -113,14 +114,14 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
// Step 2: Use session token on Instance B (different instance)
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance B should validate session token from Instance A")
assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
// Step 3: Use same session token on Instance C (yet another instance)
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance C should validate session token from Instance A")
// All instances should return identical session information
@ -140,24 +141,24 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
RoleSessionName: "revocation-test-session",
}
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err)
sessionToken := response.Credentials.SessionToken
// Verify token works on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
_, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Token should be valid on Instance B initially")
// Revoke session on Instance C
err = instanceC.RevokeSession(ctx, sessionToken)
err = instanceC.RevokeSession(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance C should be able to revoke session")
// Verify token is now invalid on Instance A (revoked by Instance C)
_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
_, err = instanceA.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err, "Token should be invalid on Instance A after revocation")
// Verify token is also invalid on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
_, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err, "Token should be invalid on Instance B after revocation")
})
@ -182,9 +183,9 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
}
// Should work on any instance
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, errA, "Instance A should process OIDC token")
require.NoError(t, errB, "Instance B should process OIDC token")
@ -200,6 +201,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
// TestSTSDistributedConfigurationRequirements tests the configuration requirements
// for cross-instance token compatibility
func TestSTSDistributedConfigurationRequirements(t *testing.T) {
_ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
t.Run("same_signing_key_required", func(t *testing.T) {
// Instance A with signing key 1
@ -319,6 +321,7 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) {
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
func TestSTSRealWorldDistributedScenarios(t *testing.T) {
ctx := context.Background()
testFilerAddress := "prod-filer-cluster:8888" // Test filer address
t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
// Simulate real production scenario:
@ -334,7 +337,6 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
SigningKey: []byte("prod-signing-key-32-characters-lon"),
SessionStoreType: "filer",
SessionStoreConfig: map[string]interface{}{
"filerAddress": "prod-filer-cluster:8888",
"basePath": "/seaweedfs/iam/sessions",
},
Providers: []*ProviderConfig{
@ -374,7 +376,7 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
DurationSeconds: int64ToPtr(7200), // 2 hours
}
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err, "Gateway 1 should handle AssumeRole")
sessionToken := stsResponse.Credentials.SessionToken
@ -383,13 +385,13 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
// Step 2: User makes S3 requests that hit different gateways via load balancer
// Simulate S3 request validation on Gateway 2
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn)
// Simulate S3 request validation on Gateway 3
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")

97
weed/iam/sts/session_store.go

@ -27,10 +27,10 @@ func NewMemorySessionStore() *MemorySessionStore {
}
}
// StoreSession stores session information in memory
func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
// StoreSession stores session information in memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error {
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
if session == nil {
@ -44,10 +44,10 @@ func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string,
return nil
}
// GetSession retrieves session information from memory
func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
// GetSession retrieves session information from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) {
if sessionId == "" {
return nil, fmt.Errorf("session ID cannot be empty")
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
m.mutex.RLock()
@ -66,10 +66,10 @@ func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (
return session, nil
}
// RevokeSession revokes a session from memory
func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string) error {
// RevokeSession revokes a session from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error {
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
m.mutex.Lock()
@ -79,8 +79,8 @@ func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string
return nil
}
// CleanupExpiredSessions removes expired sessions from memory
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error {
// CleanupExpiredSessions removes expired sessions from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -94,10 +94,10 @@ func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error {
return nil
}
// ExpireSessionForTesting manually expires a session for testing purposes
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessionId string) error {
// ExpireSessionForTesting manually expires a session for testing purposes (filerAddress ignored for memory store)
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionId string) error {
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
m.mutex.Lock()
@ -117,7 +117,6 @@ func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessio
// FilerSessionStore implements SessionStore using SeaweedFS filer
type FilerSessionStore struct {
filerGrpcAddress string
grpcDialOption grpc.DialOption
basePath string
}
@ -125,34 +124,28 @@ type FilerSessionStore struct {
// NewFilerSessionStore creates a new filer-based session store
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) {
store := &FilerSessionStore{
basePath: "/seaweedfs/iam/sessions", // Default path for session storage
basePath: DefaultSessionBasePath, // Use constant default
}
// Parse configuration
// Parse configuration - only basePath and other settings, NOT filerAddress
if config != nil {
if filerAddr, ok := config["filerAddress"].(string); ok {
store.filerGrpcAddress = filerAddr
}
if basePath, ok := config["basePath"].(string); ok {
if basePath, ok := config[ConfigFieldBasePath].(string); ok && basePath != "" {
store.basePath = strings.TrimSuffix(basePath, "/")
}
}
// Validate configuration
if store.filerGrpcAddress == "" {
return nil, fmt.Errorf("filer address is required for FilerSessionStore")
}
glog.V(2).Infof("Initialized FilerSessionStore with filer %s, basePath %s",
store.filerGrpcAddress, store.basePath)
glog.V(2).Infof("Initialized FilerSessionStore with basePath %s", store.basePath)
return store, nil
}
// StoreSession stores session information in filer
func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
if session == nil {
return fmt.Errorf("session cannot be nil")
@ -167,7 +160,7 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string,
sessionPath := f.getSessionPath(sessionId)
// Store in filer
return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.CreateEntryRequest{
Directory: f.basePath,
Entry: &filer_pb.Entry{
@ -195,13 +188,16 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string,
}
// GetSession retrieves session information from filer
func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
func (f *FilerSessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) {
if filerAddress == "" {
return nil, fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" {
return nil, fmt.Errorf("session ID cannot be empty")
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
var sessionData []byte
err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.LookupDirectoryEntryRequest{
Directory: f.basePath,
Name: f.getSessionFileName(sessionId),
@ -234,7 +230,7 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*
// Check if session has expired
if time.Now().After(session.ExpiresAt) {
// Clean up expired session
_ = f.RevokeSession(ctx, sessionId)
_ = f.RevokeSession(ctx, filerAddress, sessionId)
return nil, fmt.Errorf("session has expired")
}
@ -242,12 +238,15 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*
}
// RevokeSession revokes a session from filer
func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) error {
func (f *FilerSessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
}
return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.DeleteEntryRequest{
Directory: f.basePath,
Name: f.getSessionFileName(sessionId),
@ -280,11 +279,15 @@ func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string)
}
// CleanupExpiredSessions removes expired sessions from filer
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error {
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
now := time.Now()
expiredCount := 0
err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
// List all entries in the session directory
request := &filer_pb.ListEntriesRequest{
Directory: f.basePath,
@ -345,20 +348,14 @@ func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error {
// Helper methods
// SetFilerClient sets the filer client connection details
func (f *FilerSessionStore) SetFilerClient(filerAddress string, grpcDialOption grpc.DialOption) {
f.filerGrpcAddress = filerAddress
f.grpcDialOption = grpcDialOption
}
// withFilerClient executes a function with a filer client
func (f *FilerSessionStore) withFilerClient(fn func(client filer_pb.SeaweedFilerClient) error) error {
if f.filerGrpcAddress == "" {
return fmt.Errorf("filer address not configured")
func (f *FilerSessionStore) withFilerClient(filerAddress string, fn func(client filer_pb.SeaweedFilerClient) error) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
// Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(f.filerGrpcAddress), f.grpcDialOption, fn)
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(filerAddress), f.grpcDialOption, fn)
}
// getSessionPath returns the full path for a session

36
weed/iam/sts/sts_service.go

@ -162,17 +162,17 @@ type SessionInfo struct {
// SessionStore defines the interface for storing session information
type SessionStore interface {
// StoreSession stores session information
StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error
// StoreSession stores session information (filerAddress ignored for memory stores)
StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error
// GetSession retrieves session information
GetSession(ctx context.Context, sessionId string) (*SessionInfo, error)
// GetSession retrieves session information (filerAddress ignored for memory stores)
GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error)
// RevokeSession revokes a session
RevokeSession(ctx context.Context, sessionId string) error
// RevokeSession revokes a session (filerAddress ignored for memory stores)
RevokeSession(ctx context.Context, filerAddress string, sessionId string) error
// CleanupExpiredSessions removes expired sessions
CleanupExpiredSessions(ctx context.Context) error
// CleanupExpiredSessions removes expired sessions (filerAddress ignored for memory stores)
CleanupExpiredSessions(ctx context.Context, filerAddress string) error
}
// NewSTSService creates a new STS service
@ -300,7 +300,7 @@ func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error
}
// AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC)
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, filerAddress string, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
if !s.initialized {
return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
}
@ -361,7 +361,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
}
// 6. Store session information
if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil {
if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil {
return nil, fmt.Errorf("failed to store session: %w", err)
}
@ -379,7 +379,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
}
// AssumeRoleWithCredentials assumes a role using username/password credentials
func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) {
func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, filerAddress string, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) {
if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
}
@ -447,7 +447,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
}
// 7. Store session information
if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil {
if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil {
return nil, fmt.Errorf("failed to store session: %w", err)
}
@ -465,7 +465,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
}
// ValidateSessionToken validates a session token and returns session information
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) {
func (s *STSService) ValidateSessionToken(ctx context.Context, filerAddress string, sessionToken string) (*SessionInfo, error) {
if !s.initialized {
return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
}
@ -481,7 +481,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
}
// Retrieve session from store using session ID from claims
session, err := s.sessionStore.GetSession(ctx, claims.SessionId)
session, err := s.sessionStore.GetSession(ctx, filerAddress, claims.SessionId)
if err != nil {
return nil, fmt.Errorf(ErrSessionValidationFailed, err)
}
@ -495,7 +495,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
}
// RevokeSession revokes an active session
func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error {
func (s *STSService) RevokeSession(ctx context.Context, filerAddress string, sessionToken string) error {
if !s.initialized {
return fmt.Errorf("STS service not initialized")
}
@ -511,7 +511,7 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err
}
// Remove session from store using session ID from claims
err = s.sessionStore.RevokeSession(ctx, claims.SessionId)
err = s.sessionStore.RevokeSession(ctx, filerAddress, claims.SessionId)
if err != nil {
return fmt.Errorf("failed to revoke session: %w", err)
}
@ -679,7 +679,7 @@ func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRol
}
// ExpireSessionForTesting manually expires a session for testing purposes
func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken string) error {
func (s *STSService) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionToken string) error {
if !s.initialized {
return fmt.Errorf("STS service not initialized")
}
@ -696,7 +696,7 @@ func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken s
// Check if session store supports manual expiration (for MemorySessionStore)
if memStore, ok := s.sessionStore.(*MemorySessionStore); ok {
return memStore.ExpireSessionForTesting(ctx, sessionId)
return memStore.ExpireSessionForTesting(ctx, filerAddress, sessionId)
}
// For other session stores, we could implement similar functionality

20
weed/iam/sts/sts_service_test.go

@ -67,6 +67,7 @@ func TestSTSServiceInitialization(t *testing.T) {
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
func TestAssumeRoleWithWebIdentity(t *testing.T) {
service := setupTestSTSService(t)
testFilerAddress := "localhost:8888" // Dummy filer address for testing
tests := []struct {
name string
@ -121,7 +122,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
DurationSeconds: tt.durationSeconds,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
if tt.wantErr {
assert.Error(t, err)
@ -155,6 +156,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
func TestAssumeRoleWithLDAP(t *testing.T) {
service := setupTestSTSService(t)
testFilerAddress := "localhost:8888" // Dummy filer address for testing
tests := []struct {
name string
@ -194,7 +196,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) {
ProviderName: "test-ldap",
}
response, err := service.AssumeRoleWithCredentials(ctx, request)
response, err := service.AssumeRoleWithCredentials(ctx, testFilerAddress, request)
if tt.wantErr {
assert.Error(t, err)
@ -212,6 +214,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) {
func TestSessionTokenValidation(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// First, create a session
request := &AssumeRoleWithWebIdentityRequest{
@ -220,7 +223,7 @@ func TestSessionTokenValidation(t *testing.T) {
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
require.NoError(t, err)
require.NotNil(t, response)
@ -250,7 +253,7 @@ func TestSessionTokenValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session, err := service.ValidateSessionToken(ctx, tt.token)
session, err := service.ValidateSessionToken(ctx, testFilerAddress, tt.token)
if tt.wantErr {
assert.Error(t, err)
@ -269,6 +272,7 @@ func TestSessionTokenValidation(t *testing.T) {
func TestSessionRevocation(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// Create a session first
request := &AssumeRoleWithWebIdentityRequest{
@ -277,22 +281,22 @@ func TestSessionRevocation(t *testing.T) {
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
require.NoError(t, err)
sessionToken := response.Credentials.SessionToken
// Verify token is valid before revocation
session, err := service.ValidateSessionToken(ctx, sessionToken)
session, err := service.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.NoError(t, err)
assert.NotNil(t, session)
// Revoke the session
err = service.RevokeSession(ctx, sessionToken)
err = service.RevokeSession(ctx, testFilerAddress, sessionToken)
assert.NoError(t, err)
// Verify token is no longer valid after revocation
session, err = service.ValidateSessionToken(ctx, sessionToken)
session, err = service.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err)
assert.Nil(t, session)
}

Loading…
Cancel
Save