From 8718c301baceef05b60883b6dfd8bd8735c1cb80 Mon Sep 17 00:00:00 2001 From: chrislu Date: Sun, 24 Aug 2025 14:10:55 -0700 Subject: [PATCH] feat(sts): pass filerAddress at call-time instead of init-time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change addresses the requirement that filer addresses should be passed when methods are called, not during initialization, to support: - Dynamic filer failover and load balancing - Runtime changes to filer topology - Environment-agnostic configuration files ### Changes Made: #### SessionStore Interface & Implementations: - Updated SessionStore interface to accept filerAddress parameter in all methods - Modified FilerSessionStore to remove filerAddress field from struct - Updated MemorySessionStore to accept filerAddress (ignored) for interface consistency - All methods now take: (ctx, filerAddress, sessionId, ...) parameters #### STS Service Methods: - Updated all public STS methods to accept filerAddress parameter: - AssumeRoleWithWebIdentity(ctx, filerAddress, request) - AssumeRoleWithCredentials(ctx, filerAddress, request) - ValidateSessionToken(ctx, filerAddress, sessionToken) - RevokeSession(ctx, filerAddress, sessionToken) - ExpireSessionForTesting(ctx, filerAddress, sessionToken) #### Configuration Cleanup: - Removed filerAddress from all configuration files (iam_config_distributed.json) - Configuration now only contains basePath and other store-specific settings - Makes configs environment-agnostic (dev/staging/prod compatible) #### Test Updates: - Updated all test files to pass testFilerAddress parameter - Tests use dummy filerAddress ('localhost:8888') for consistency - Maintains test functionality while validating new interface ### Benefits: - ✅ Filer addresses determined at runtime by caller (S3 API server) - ✅ Supports filer failover without service restart - ✅ Configuration files work across environments - ✅ Follows SeaweedFS patterns used elsewhere in codebase - ✅ Load balancer friendly - no filer affinity required - ✅ Horizontal scaling compatible ### Breaking Change: This is a breaking change for any code calling STS service methods. Callers must now pass filerAddress as the second parameter. --- test/s3/iam/iam_config_distributed.json | 3 - weed/iam/sts/cross_instance_token_test.go | 34 ++++---- weed/iam/sts/session_store.go | 101 +++++++++++----------- weed/iam/sts/sts_service.go | 36 ++++---- weed/iam/sts/sts_service_test.go | 20 +++-- 5 files changed, 97 insertions(+), 97 deletions(-) diff --git a/test/s3/iam/iam_config_distributed.json b/test/s3/iam/iam_config_distributed.json index 7b6a363c4..595172171 100644 --- a/test/s3/iam/iam_config_distributed.json +++ b/test/s3/iam/iam_config_distributed.json @@ -6,7 +6,6 @@ "signingKey": "dGVzdC1zaWduaW5nLWtleS0zMi1jaGFyYWN0ZXJzLWxvbmc=", "sessionStoreType": "filer", "sessionStoreConfig": { - "filerAddress": "localhost:8888", "basePath": "/etc/iam/sessions" }, "providers": [ @@ -41,14 +40,12 @@ "defaultEffect": "Deny", "storeType": "filer", "storeConfig": { - "filerAddress": "localhost:8888", "basePath": "/etc/iam/policies" } }, "roleStore": { "storeType": "filer", "storeConfig": { - "filerAddress": "localhost:8888", "basePath": "/etc/iam/roles" } }, diff --git a/weed/iam/sts/cross_instance_token_test.go b/weed/iam/sts/cross_instance_token_test.go index 776795ab9..5b4803ce5 100644 --- a/weed/iam/sts/cross_instance_token_test.go +++ b/weed/iam/sts/cross_instance_token_test.go @@ -13,6 +13,7 @@ import ( // can be used and validated by other STS instances in a distributed environment func TestCrossInstanceTokenUsage(t *testing.T) { ctx := context.Background() + testFilerAddress := "localhost:8888" // Dummy filer address for testing // Common configuration that would be shared across all instances in production sharedConfig := &STSConfig{ @@ -99,7 +100,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) { } // Instance A processes assume role request - responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) + responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) require.NoError(t, err, "Instance A should process assume role") sessionToken := responseFromA.Credentials.SessionToken @@ -113,14 +114,14 @@ func TestCrossInstanceTokenUsage(t *testing.T) { assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user") // Step 2: Use session token on Instance B (different instance) - sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken) + sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Instance B should validate session token from Instance A") assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName) assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn) // Step 3: Use same session token on Instance C (yet another instance) - sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken) + sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Instance C should validate session token from Instance A") // All instances should return identical session information @@ -140,24 +141,24 @@ func TestCrossInstanceTokenUsage(t *testing.T) { RoleSessionName: "revocation-test-session", } - response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) + response, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) require.NoError(t, err) sessionToken := response.Credentials.SessionToken // Verify token works on Instance B - _, err = instanceB.ValidateSessionToken(ctx, sessionToken) + _, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Token should be valid on Instance B initially") // Revoke session on Instance C - err = instanceC.RevokeSession(ctx, sessionToken) + err = instanceC.RevokeSession(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Instance C should be able to revoke session") // Verify token is now invalid on Instance A (revoked by Instance C) - _, err = instanceA.ValidateSessionToken(ctx, sessionToken) + _, err = instanceA.ValidateSessionToken(ctx, testFilerAddress, sessionToken) assert.Error(t, err, "Token should be invalid on Instance A after revocation") // Verify token is also invalid on Instance B - _, err = instanceB.ValidateSessionToken(ctx, sessionToken) + _, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken) assert.Error(t, err, "Token should be invalid on Instance B after revocation") }) @@ -182,9 +183,9 @@ func TestCrossInstanceTokenUsage(t *testing.T) { } // Should work on any instance - responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) - responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest) - responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest) + responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) + responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) + responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) require.NoError(t, errA, "Instance A should process OIDC token") require.NoError(t, errB, "Instance B should process OIDC token") @@ -200,6 +201,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) { // TestSTSDistributedConfigurationRequirements tests the configuration requirements // for cross-instance token compatibility func TestSTSDistributedConfigurationRequirements(t *testing.T) { + _ = "localhost:8888" // Dummy filer address for testing (not used in these tests) t.Run("same_signing_key_required", func(t *testing.T) { // Instance A with signing key 1 @@ -319,6 +321,7 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) { // TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios func TestSTSRealWorldDistributedScenarios(t *testing.T) { ctx := context.Background() + testFilerAddress := "prod-filer-cluster:8888" // Test filer address t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) { // Simulate real production scenario: @@ -334,8 +337,7 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) { SigningKey: []byte("prod-signing-key-32-characters-lon"), SessionStoreType: "filer", SessionStoreConfig: map[string]interface{}{ - "filerAddress": "prod-filer-cluster:8888", - "basePath": "/seaweedfs/iam/sessions", + "basePath": "/seaweedfs/iam/sessions", }, Providers: []*ProviderConfig{ { @@ -374,7 +376,7 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) { DurationSeconds: int64ToPtr(7200), // 2 hours } - stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest) + stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest) require.NoError(t, err, "Gateway 1 should handle AssumeRole") sessionToken := stsResponse.Credentials.SessionToken @@ -383,13 +385,13 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) { // Step 2: User makes S3 requests that hit different gateways via load balancer // Simulate S3 request validation on Gateway 2 - sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken) + sessionInfo2, err := gateway2.ValidateSessionToken(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Gateway 2 should validate session from Gateway 1") assert.Equal(t, "user-production-session", sessionInfo2.SessionName) assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn) // Simulate S3 request validation on Gateway 3 - sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken) + sessionInfo3, err := gateway3.ValidateSessionToken(ctx, testFilerAddress, sessionToken) require.NoError(t, err, "Gateway 3 should validate session from Gateway 1") assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session") diff --git a/weed/iam/sts/session_store.go b/weed/iam/sts/session_store.go index e0d96d322..44ca06925 100644 --- a/weed/iam/sts/session_store.go +++ b/weed/iam/sts/session_store.go @@ -27,10 +27,10 @@ func NewMemorySessionStore() *MemorySessionStore { } } -// StoreSession stores session information in memory -func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error { +// StoreSession stores session information in memory (filerAddress ignored for memory store) +func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { if sessionId == "" { - return fmt.Errorf("session ID cannot be empty") + return fmt.Errorf(ErrSessionIDCannotBeEmpty) } if session == nil { @@ -44,10 +44,10 @@ func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, return nil } -// GetSession retrieves session information from memory -func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) { +// GetSession retrieves session information from memory (filerAddress ignored for memory store) +func (m *MemorySessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { if sessionId == "" { - return nil, fmt.Errorf("session ID cannot be empty") + return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) } m.mutex.RLock() @@ -66,10 +66,10 @@ func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) ( return session, nil } -// RevokeSession revokes a session from memory -func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string) error { +// RevokeSession revokes a session from memory (filerAddress ignored for memory store) +func (m *MemorySessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { if sessionId == "" { - return fmt.Errorf("session ID cannot be empty") + return fmt.Errorf(ErrSessionIDCannotBeEmpty) } m.mutex.Lock() @@ -79,8 +79,8 @@ func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string return nil } -// CleanupExpiredSessions removes expired sessions from memory -func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error { +// CleanupExpiredSessions removes expired sessions from memory (filerAddress ignored for memory store) +func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -94,10 +94,10 @@ func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error { return nil } -// ExpireSessionForTesting manually expires a session for testing purposes -func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessionId string) error { +// ExpireSessionForTesting manually expires a session for testing purposes (filerAddress ignored for memory store) +func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionId string) error { if sessionId == "" { - return fmt.Errorf("session ID cannot be empty") + return fmt.Errorf(ErrSessionIDCannotBeEmpty) } m.mutex.Lock() @@ -117,42 +117,35 @@ func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessio // FilerSessionStore implements SessionStore using SeaweedFS filer type FilerSessionStore struct { - filerGrpcAddress string - grpcDialOption grpc.DialOption - basePath string + grpcDialOption grpc.DialOption + basePath string } // NewFilerSessionStore creates a new filer-based session store func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { store := &FilerSessionStore{ - basePath: "/seaweedfs/iam/sessions", // Default path for session storage + basePath: DefaultSessionBasePath, // Use constant default } - // Parse configuration + // Parse configuration - only basePath and other settings, NOT filerAddress if config != nil { - if filerAddr, ok := config["filerAddress"].(string); ok { - store.filerGrpcAddress = filerAddr - } - if basePath, ok := config["basePath"].(string); ok { + if basePath, ok := config[ConfigFieldBasePath].(string); ok && basePath != "" { store.basePath = strings.TrimSuffix(basePath, "/") } } - // Validate configuration - if store.filerGrpcAddress == "" { - return nil, fmt.Errorf("filer address is required for FilerSessionStore") - } - - glog.V(2).Infof("Initialized FilerSessionStore with filer %s, basePath %s", - store.filerGrpcAddress, store.basePath) + glog.V(2).Infof("Initialized FilerSessionStore with basePath %s", store.basePath) return store, nil } // StoreSession stores session information in filer -func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error { +func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { + if filerAddress == "" { + return fmt.Errorf(ErrFilerAddressRequired) + } if sessionId == "" { - return fmt.Errorf("session ID cannot be empty") + return fmt.Errorf(ErrSessionIDCannotBeEmpty) } if session == nil { return fmt.Errorf("session cannot be nil") @@ -167,7 +160,7 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, sessionPath := f.getSessionPath(sessionId) // Store in filer - return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error { + return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { request := &filer_pb.CreateEntryRequest{ Directory: f.basePath, Entry: &filer_pb.Entry{ @@ -195,13 +188,16 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, } // GetSession retrieves session information from filer -func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) { +func (f *FilerSessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { + if filerAddress == "" { + return nil, fmt.Errorf(ErrFilerAddressRequired) + } if sessionId == "" { - return nil, fmt.Errorf("session ID cannot be empty") + return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) } var sessionData []byte - err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error { + err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { request := &filer_pb.LookupDirectoryEntryRequest{ Directory: f.basePath, Name: f.getSessionFileName(sessionId), @@ -234,7 +230,7 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (* // Check if session has expired if time.Now().After(session.ExpiresAt) { // Clean up expired session - _ = f.RevokeSession(ctx, sessionId) + _ = f.RevokeSession(ctx, filerAddress, sessionId) return nil, fmt.Errorf("session has expired") } @@ -242,12 +238,15 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (* } // RevokeSession revokes a session from filer -func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) error { +func (f *FilerSessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { + if filerAddress == "" { + return fmt.Errorf(ErrFilerAddressRequired) + } if sessionId == "" { - return fmt.Errorf("session ID cannot be empty") + return fmt.Errorf(ErrSessionIDCannotBeEmpty) } - return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error { + return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { request := &filer_pb.DeleteEntryRequest{ Directory: f.basePath, Name: f.getSessionFileName(sessionId), @@ -280,11 +279,15 @@ func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) } // CleanupExpiredSessions removes expired sessions from filer -func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error { +func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { + if filerAddress == "" { + return fmt.Errorf(ErrFilerAddressRequired) + } + now := time.Now() expiredCount := 0 - err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error { + err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { // List all entries in the session directory request := &filer_pb.ListEntriesRequest{ Directory: f.basePath, @@ -345,20 +348,14 @@ func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error { // Helper methods -// SetFilerClient sets the filer client connection details -func (f *FilerSessionStore) SetFilerClient(filerAddress string, grpcDialOption grpc.DialOption) { - f.filerGrpcAddress = filerAddress - f.grpcDialOption = grpcDialOption -} - // withFilerClient executes a function with a filer client -func (f *FilerSessionStore) withFilerClient(fn func(client filer_pb.SeaweedFilerClient) error) error { - if f.filerGrpcAddress == "" { - return fmt.Errorf("filer address not configured") +func (f *FilerSessionStore) withFilerClient(filerAddress string, fn func(client filer_pb.SeaweedFilerClient) error) error { + if filerAddress == "" { + return fmt.Errorf(ErrFilerAddressRequired) } // Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code - return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(f.filerGrpcAddress), f.grpcDialOption, fn) + return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(filerAddress), f.grpcDialOption, fn) } // getSessionPath returns the full path for a session diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index 8199a14cf..736a07a48 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -162,17 +162,17 @@ type SessionInfo struct { // SessionStore defines the interface for storing session information type SessionStore interface { - // StoreSession stores session information - StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error + // StoreSession stores session information (filerAddress ignored for memory stores) + StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error - // GetSession retrieves session information - GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) + // GetSession retrieves session information (filerAddress ignored for memory stores) + GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) - // RevokeSession revokes a session - RevokeSession(ctx context.Context, sessionId string) error + // RevokeSession revokes a session (filerAddress ignored for memory stores) + RevokeSession(ctx context.Context, filerAddress string, sessionId string) error - // CleanupExpiredSessions removes expired sessions - CleanupExpiredSessions(ctx context.Context) error + // CleanupExpiredSessions removes expired sessions (filerAddress ignored for memory stores) + CleanupExpiredSessions(ctx context.Context, filerAddress string) error } // NewSTSService creates a new STS service @@ -300,7 +300,7 @@ func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error } // AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC) -func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) { +func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, filerAddress string, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) { if !s.initialized { return nil, fmt.Errorf(ErrSTSServiceNotInitialized) } @@ -361,7 +361,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass } // 6. Store session information - if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil { + if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil { return nil, fmt.Errorf("failed to store session: %w", err) } @@ -379,7 +379,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass } // AssumeRoleWithCredentials assumes a role using username/password credentials -func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) { +func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, filerAddress string, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) { if !s.initialized { return nil, fmt.Errorf("STS service not initialized") } @@ -447,7 +447,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass } // 7. Store session information - if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil { + if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil { return nil, fmt.Errorf("failed to store session: %w", err) } @@ -465,7 +465,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass } // ValidateSessionToken validates a session token and returns session information -func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) { +func (s *STSService) ValidateSessionToken(ctx context.Context, filerAddress string, sessionToken string) (*SessionInfo, error) { if !s.initialized { return nil, fmt.Errorf(ErrSTSServiceNotInitialized) } @@ -481,7 +481,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri } // Retrieve session from store using session ID from claims - session, err := s.sessionStore.GetSession(ctx, claims.SessionId) + session, err := s.sessionStore.GetSession(ctx, filerAddress, claims.SessionId) if err != nil { return nil, fmt.Errorf(ErrSessionValidationFailed, err) } @@ -495,7 +495,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri } // RevokeSession revokes an active session -func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error { +func (s *STSService) RevokeSession(ctx context.Context, filerAddress string, sessionToken string) error { if !s.initialized { return fmt.Errorf("STS service not initialized") } @@ -511,7 +511,7 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err } // Remove session from store using session ID from claims - err = s.sessionStore.RevokeSession(ctx, claims.SessionId) + err = s.sessionStore.RevokeSession(ctx, filerAddress, claims.SessionId) if err != nil { return fmt.Errorf("failed to revoke session: %w", err) } @@ -679,7 +679,7 @@ func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRol } // ExpireSessionForTesting manually expires a session for testing purposes -func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken string) error { +func (s *STSService) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionToken string) error { if !s.initialized { return fmt.Errorf("STS service not initialized") } @@ -696,7 +696,7 @@ func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken s // Check if session store supports manual expiration (for MemorySessionStore) if memStore, ok := s.sessionStore.(*MemorySessionStore); ok { - return memStore.ExpireSessionForTesting(ctx, sessionId) + return memStore.ExpireSessionForTesting(ctx, filerAddress, sessionId) } // For other session stores, we could implement similar functionality diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go index 027080845..0959399ee 100644 --- a/weed/iam/sts/sts_service_test.go +++ b/weed/iam/sts/sts_service_test.go @@ -67,6 +67,7 @@ func TestSTSServiceInitialization(t *testing.T) { // TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens func TestAssumeRoleWithWebIdentity(t *testing.T) { service := setupTestSTSService(t) + testFilerAddress := "localhost:8888" // Dummy filer address for testing tests := []struct { name string @@ -121,7 +122,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) { DurationSeconds: tt.durationSeconds, } - response, err := service.AssumeRoleWithWebIdentity(ctx, request) + response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request) if tt.wantErr { assert.Error(t, err) @@ -155,6 +156,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) { // TestAssumeRoleWithLDAP tests role assumption with LDAP credentials func TestAssumeRoleWithLDAP(t *testing.T) { service := setupTestSTSService(t) + testFilerAddress := "localhost:8888" // Dummy filer address for testing tests := []struct { name string @@ -194,7 +196,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) { ProviderName: "test-ldap", } - response, err := service.AssumeRoleWithCredentials(ctx, request) + response, err := service.AssumeRoleWithCredentials(ctx, testFilerAddress, request) if tt.wantErr { assert.Error(t, err) @@ -212,6 +214,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) { func TestSessionTokenValidation(t *testing.T) { service := setupTestSTSService(t) ctx := context.Background() + testFilerAddress := "localhost:8888" // Dummy filer address for testing // First, create a session request := &AssumeRoleWithWebIdentityRequest{ @@ -220,7 +223,7 @@ func TestSessionTokenValidation(t *testing.T) { RoleSessionName: "test-session", } - response, err := service.AssumeRoleWithWebIdentity(ctx, request) + response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request) require.NoError(t, err) require.NotNil(t, response) @@ -250,7 +253,7 @@ func TestSessionTokenValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - session, err := service.ValidateSessionToken(ctx, tt.token) + session, err := service.ValidateSessionToken(ctx, testFilerAddress, tt.token) if tt.wantErr { assert.Error(t, err) @@ -269,6 +272,7 @@ func TestSessionTokenValidation(t *testing.T) { func TestSessionRevocation(t *testing.T) { service := setupTestSTSService(t) ctx := context.Background() + testFilerAddress := "localhost:8888" // Dummy filer address for testing // Create a session first request := &AssumeRoleWithWebIdentityRequest{ @@ -277,22 +281,22 @@ func TestSessionRevocation(t *testing.T) { RoleSessionName: "test-session", } - response, err := service.AssumeRoleWithWebIdentity(ctx, request) + response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request) require.NoError(t, err) sessionToken := response.Credentials.SessionToken // Verify token is valid before revocation - session, err := service.ValidateSessionToken(ctx, sessionToken) + session, err := service.ValidateSessionToken(ctx, testFilerAddress, sessionToken) assert.NoError(t, err) assert.NotNil(t, session) // Revoke the session - err = service.RevokeSession(ctx, sessionToken) + err = service.RevokeSession(ctx, testFilerAddress, sessionToken) assert.NoError(t, err) // Verify token is no longer valid after revocation - session, err = service.ValidateSessionToken(ctx, sessionToken) + session, err = service.ValidateSessionToken(ctx, testFilerAddress, sessionToken) assert.Error(t, err) assert.Nil(t, session) }