Browse Source

feat(sts): pass filerAddress at call-time instead of init-time

This change addresses the requirement that filer addresses should be
passed when methods are called, not during initialization, to support:
- Dynamic filer failover and load balancing
- Runtime changes to filer topology
- Environment-agnostic configuration files

### Changes Made:

#### SessionStore Interface & Implementations:
- Updated SessionStore interface to accept filerAddress parameter in all methods
- Modified FilerSessionStore to remove filerAddress field from struct
- Updated MemorySessionStore to accept filerAddress (ignored) for interface consistency
- All methods now take: (ctx, filerAddress, sessionId, ...) parameters

#### STS Service Methods:
- Updated all public STS methods to accept filerAddress parameter:
  - AssumeRoleWithWebIdentity(ctx, filerAddress, request)
  - AssumeRoleWithCredentials(ctx, filerAddress, request)
  - ValidateSessionToken(ctx, filerAddress, sessionToken)
  - RevokeSession(ctx, filerAddress, sessionToken)
  - ExpireSessionForTesting(ctx, filerAddress, sessionToken)

#### Configuration Cleanup:
- Removed filerAddress from all configuration files (iam_config_distributed.json)
- Configuration now only contains basePath and other store-specific settings
- Makes configs environment-agnostic (dev/staging/prod compatible)

#### Test Updates:
- Updated all test files to pass testFilerAddress parameter
- Tests use dummy filerAddress ('localhost:8888') for consistency
- Maintains test functionality while validating new interface

### Benefits:
-  Filer addresses determined at runtime by caller (S3 API server)
-  Supports filer failover without service restart
-  Configuration files work across environments
-  Follows SeaweedFS patterns used elsewhere in codebase
-  Load balancer friendly - no filer affinity required
-  Horizontal scaling compatible

### Breaking Change:
This is a breaking change for any code calling STS service methods.
Callers must now pass filerAddress as the second parameter.
pull/7160/head
chrislu 2 months ago
parent
commit
8718c301ba
  1. 3
      test/s3/iam/iam_config_distributed.json
  2. 32
      weed/iam/sts/cross_instance_token_test.go
  3. 97
      weed/iam/sts/session_store.go
  4. 36
      weed/iam/sts/sts_service.go
  5. 20
      weed/iam/sts/sts_service_test.go

3
test/s3/iam/iam_config_distributed.json

@ -6,7 +6,6 @@
"signingKey": "dGVzdC1zaWduaW5nLWtleS0zMi1jaGFyYWN0ZXJzLWxvbmc=", "signingKey": "dGVzdC1zaWduaW5nLWtleS0zMi1jaGFyYWN0ZXJzLWxvbmc=",
"sessionStoreType": "filer", "sessionStoreType": "filer",
"sessionStoreConfig": { "sessionStoreConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/sessions" "basePath": "/etc/iam/sessions"
}, },
"providers": [ "providers": [
@ -41,14 +40,12 @@
"defaultEffect": "Deny", "defaultEffect": "Deny",
"storeType": "filer", "storeType": "filer",
"storeConfig": { "storeConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/policies" "basePath": "/etc/iam/policies"
} }
}, },
"roleStore": { "roleStore": {
"storeType": "filer", "storeType": "filer",
"storeConfig": { "storeConfig": {
"filerAddress": "localhost:8888",
"basePath": "/etc/iam/roles" "basePath": "/etc/iam/roles"
} }
}, },

32
weed/iam/sts/cross_instance_token_test.go

@ -13,6 +13,7 @@ import (
// can be used and validated by other STS instances in a distributed environment // can be used and validated by other STS instances in a distributed environment
func TestCrossInstanceTokenUsage(t *testing.T) { func TestCrossInstanceTokenUsage(t *testing.T) {
ctx := context.Background() ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// Common configuration that would be shared across all instances in production // Common configuration that would be shared across all instances in production
sharedConfig := &STSConfig{ sharedConfig := &STSConfig{
@ -99,7 +100,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
} }
// Instance A processes assume role request // Instance A processes assume role request
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err, "Instance A should process assume role") require.NoError(t, err, "Instance A should process assume role")
sessionToken := responseFromA.Credentials.SessionToken sessionToken := responseFromA.Credentials.SessionToken
@ -113,14 +114,14 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user") assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
// Step 2: Use session token on Instance B (different instance) // Step 2: Use session token on Instance B (different instance)
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance B should validate session token from Instance A") require.NoError(t, err, "Instance B should validate session token from Instance A")
assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName) assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn) assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
// Step 3: Use same session token on Instance C (yet another instance) // Step 3: Use same session token on Instance C (yet another instance)
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance C should validate session token from Instance A") require.NoError(t, err, "Instance C should validate session token from Instance A")
// All instances should return identical session information // All instances should return identical session information
@ -140,24 +141,24 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
RoleSessionName: "revocation-test-session", RoleSessionName: "revocation-test-session",
} }
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err) require.NoError(t, err)
sessionToken := response.Credentials.SessionToken sessionToken := response.Credentials.SessionToken
// Verify token works on Instance B // Verify token works on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
_, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Token should be valid on Instance B initially") require.NoError(t, err, "Token should be valid on Instance B initially")
// Revoke session on Instance C // Revoke session on Instance C
err = instanceC.RevokeSession(ctx, sessionToken)
err = instanceC.RevokeSession(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Instance C should be able to revoke session") require.NoError(t, err, "Instance C should be able to revoke session")
// Verify token is now invalid on Instance A (revoked by Instance C) // Verify token is now invalid on Instance A (revoked by Instance C)
_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
_, err = instanceA.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err, "Token should be invalid on Instance A after revocation") assert.Error(t, err, "Token should be invalid on Instance A after revocation")
// Verify token is also invalid on Instance B // Verify token is also invalid on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
_, err = instanceB.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err, "Token should be invalid on Instance B after revocation") assert.Error(t, err, "Token should be invalid on Instance B after revocation")
}) })
@ -182,9 +183,9 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
} }
// Should work on any instance // Should work on any instance
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, errA, "Instance A should process OIDC token") require.NoError(t, errA, "Instance A should process OIDC token")
require.NoError(t, errB, "Instance B should process OIDC token") require.NoError(t, errB, "Instance B should process OIDC token")
@ -200,6 +201,7 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
// TestSTSDistributedConfigurationRequirements tests the configuration requirements // TestSTSDistributedConfigurationRequirements tests the configuration requirements
// for cross-instance token compatibility // for cross-instance token compatibility
func TestSTSDistributedConfigurationRequirements(t *testing.T) { func TestSTSDistributedConfigurationRequirements(t *testing.T) {
_ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
t.Run("same_signing_key_required", func(t *testing.T) { t.Run("same_signing_key_required", func(t *testing.T) {
// Instance A with signing key 1 // Instance A with signing key 1
@ -319,6 +321,7 @@ func TestSTSDistributedConfigurationRequirements(t *testing.T) {
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios // TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
func TestSTSRealWorldDistributedScenarios(t *testing.T) { func TestSTSRealWorldDistributedScenarios(t *testing.T) {
ctx := context.Background() ctx := context.Background()
testFilerAddress := "prod-filer-cluster:8888" // Test filer address
t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) { t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
// Simulate real production scenario: // Simulate real production scenario:
@ -334,7 +337,6 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
SigningKey: []byte("prod-signing-key-32-characters-lon"), SigningKey: []byte("prod-signing-key-32-characters-lon"),
SessionStoreType: "filer", SessionStoreType: "filer",
SessionStoreConfig: map[string]interface{}{ SessionStoreConfig: map[string]interface{}{
"filerAddress": "prod-filer-cluster:8888",
"basePath": "/seaweedfs/iam/sessions", "basePath": "/seaweedfs/iam/sessions",
}, },
Providers: []*ProviderConfig{ Providers: []*ProviderConfig{
@ -374,7 +376,7 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
DurationSeconds: int64ToPtr(7200), // 2 hours DurationSeconds: int64ToPtr(7200), // 2 hours
} }
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, testFilerAddress, assumeRequest)
require.NoError(t, err, "Gateway 1 should handle AssumeRole") require.NoError(t, err, "Gateway 1 should handle AssumeRole")
sessionToken := stsResponse.Credentials.SessionToken sessionToken := stsResponse.Credentials.SessionToken
@ -383,13 +385,13 @@ func TestSTSRealWorldDistributedScenarios(t *testing.T) {
// Step 2: User makes S3 requests that hit different gateways via load balancer // Step 2: User makes S3 requests that hit different gateways via load balancer
// Simulate S3 request validation on Gateway 2 // Simulate S3 request validation on Gateway 2
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Gateway 2 should validate session from Gateway 1") require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
assert.Equal(t, "user-production-session", sessionInfo2.SessionName) assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn) assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn)
// Simulate S3 request validation on Gateway 3 // Simulate S3 request validation on Gateway 3
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
require.NoError(t, err, "Gateway 3 should validate session from Gateway 1") require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session") assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")

97
weed/iam/sts/session_store.go

@ -27,10 +27,10 @@ func NewMemorySessionStore() *MemorySessionStore {
} }
} }
// StoreSession stores session information in memory
func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
// StoreSession stores session information in memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error {
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
if session == nil { if session == nil {
@ -44,10 +44,10 @@ func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string,
return nil return nil
} }
// GetSession retrieves session information from memory
func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
// GetSession retrieves session information from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) {
if sessionId == "" { if sessionId == "" {
return nil, fmt.Errorf("session ID cannot be empty")
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
m.mutex.RLock() m.mutex.RLock()
@ -66,10 +66,10 @@ func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (
return session, nil return session, nil
} }
// RevokeSession revokes a session from memory
func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string) error {
// RevokeSession revokes a session from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error {
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
m.mutex.Lock() m.mutex.Lock()
@ -79,8 +79,8 @@ func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string
return nil return nil
} }
// CleanupExpiredSessions removes expired sessions from memory
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error {
// CleanupExpiredSessions removes expired sessions from memory (filerAddress ignored for memory store)
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -94,10 +94,10 @@ func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error {
return nil return nil
} }
// ExpireSessionForTesting manually expires a session for testing purposes
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessionId string) error {
// ExpireSessionForTesting manually expires a session for testing purposes (filerAddress ignored for memory store)
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionId string) error {
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
m.mutex.Lock() m.mutex.Lock()
@ -117,7 +117,6 @@ func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, sessio
// FilerSessionStore implements SessionStore using SeaweedFS filer // FilerSessionStore implements SessionStore using SeaweedFS filer
type FilerSessionStore struct { type FilerSessionStore struct {
filerGrpcAddress string
grpcDialOption grpc.DialOption grpcDialOption grpc.DialOption
basePath string basePath string
} }
@ -125,34 +124,28 @@ type FilerSessionStore struct {
// NewFilerSessionStore creates a new filer-based session store // NewFilerSessionStore creates a new filer-based session store
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) {
store := &FilerSessionStore{ store := &FilerSessionStore{
basePath: "/seaweedfs/iam/sessions", // Default path for session storage
basePath: DefaultSessionBasePath, // Use constant default
} }
// Parse configuration
// Parse configuration - only basePath and other settings, NOT filerAddress
if config != nil { if config != nil {
if filerAddr, ok := config["filerAddress"].(string); ok {
store.filerGrpcAddress = filerAddr
}
if basePath, ok := config["basePath"].(string); ok {
if basePath, ok := config[ConfigFieldBasePath].(string); ok && basePath != "" {
store.basePath = strings.TrimSuffix(basePath, "/") store.basePath = strings.TrimSuffix(basePath, "/")
} }
} }
// Validate configuration
if store.filerGrpcAddress == "" {
return nil, fmt.Errorf("filer address is required for FilerSessionStore")
}
glog.V(2).Infof("Initialized FilerSessionStore with filer %s, basePath %s",
store.filerGrpcAddress, store.basePath)
glog.V(2).Infof("Initialized FilerSessionStore with basePath %s", store.basePath)
return store, nil return store, nil
} }
// StoreSession stores session information in filer // StoreSession stores session information in filer
func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
if session == nil { if session == nil {
return fmt.Errorf("session cannot be nil") return fmt.Errorf("session cannot be nil")
@ -167,7 +160,7 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string,
sessionPath := f.getSessionPath(sessionId) sessionPath := f.getSessionPath(sessionId)
// Store in filer // Store in filer
return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.CreateEntryRequest{ request := &filer_pb.CreateEntryRequest{
Directory: f.basePath, Directory: f.basePath,
Entry: &filer_pb.Entry{ Entry: &filer_pb.Entry{
@ -195,13 +188,16 @@ func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string,
} }
// GetSession retrieves session information from filer // GetSession retrieves session information from filer
func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
func (f *FilerSessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) {
if filerAddress == "" {
return nil, fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" { if sessionId == "" {
return nil, fmt.Errorf("session ID cannot be empty")
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
var sessionData []byte var sessionData []byte
err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.LookupDirectoryEntryRequest{ request := &filer_pb.LookupDirectoryEntryRequest{
Directory: f.basePath, Directory: f.basePath,
Name: f.getSessionFileName(sessionId), Name: f.getSessionFileName(sessionId),
@ -234,7 +230,7 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*
// Check if session has expired // Check if session has expired
if time.Now().After(session.ExpiresAt) { if time.Now().After(session.ExpiresAt) {
// Clean up expired session // Clean up expired session
_ = f.RevokeSession(ctx, sessionId)
_ = f.RevokeSession(ctx, filerAddress, sessionId)
return nil, fmt.Errorf("session has expired") return nil, fmt.Errorf("session has expired")
} }
@ -242,12 +238,15 @@ func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*
} }
// RevokeSession revokes a session from filer // RevokeSession revokes a session from filer
func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) error {
func (f *FilerSessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
return fmt.Errorf(ErrSessionIDCannotBeEmpty)
} }
return f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.DeleteEntryRequest{ request := &filer_pb.DeleteEntryRequest{
Directory: f.basePath, Directory: f.basePath,
Name: f.getSessionFileName(sessionId), Name: f.getSessionFileName(sessionId),
@ -280,11 +279,15 @@ func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string)
} }
// CleanupExpiredSessions removes expired sessions from filer // CleanupExpiredSessions removes expired sessions from filer
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error {
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
}
now := time.Now() now := time.Now()
expiredCount := 0 expiredCount := 0
err := f.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error {
// List all entries in the session directory // List all entries in the session directory
request := &filer_pb.ListEntriesRequest{ request := &filer_pb.ListEntriesRequest{
Directory: f.basePath, Directory: f.basePath,
@ -345,20 +348,14 @@ func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error {
// Helper methods // Helper methods
// SetFilerClient sets the filer client connection details
func (f *FilerSessionStore) SetFilerClient(filerAddress string, grpcDialOption grpc.DialOption) {
f.filerGrpcAddress = filerAddress
f.grpcDialOption = grpcDialOption
}
// withFilerClient executes a function with a filer client // withFilerClient executes a function with a filer client
func (f *FilerSessionStore) withFilerClient(fn func(client filer_pb.SeaweedFilerClient) error) error {
if f.filerGrpcAddress == "" {
return fmt.Errorf("filer address not configured")
func (f *FilerSessionStore) withFilerClient(filerAddress string, fn func(client filer_pb.SeaweedFilerClient) error) error {
if filerAddress == "" {
return fmt.Errorf(ErrFilerAddressRequired)
} }
// Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code // Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(f.filerGrpcAddress), f.grpcDialOption, fn)
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(filerAddress), f.grpcDialOption, fn)
} }
// getSessionPath returns the full path for a session // getSessionPath returns the full path for a session

36
weed/iam/sts/sts_service.go

@ -162,17 +162,17 @@ type SessionInfo struct {
// SessionStore defines the interface for storing session information // SessionStore defines the interface for storing session information
type SessionStore interface { type SessionStore interface {
// StoreSession stores session information
StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error
// StoreSession stores session information (filerAddress ignored for memory stores)
StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error
// GetSession retrieves session information
GetSession(ctx context.Context, sessionId string) (*SessionInfo, error)
// GetSession retrieves session information (filerAddress ignored for memory stores)
GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error)
// RevokeSession revokes a session
RevokeSession(ctx context.Context, sessionId string) error
// RevokeSession revokes a session (filerAddress ignored for memory stores)
RevokeSession(ctx context.Context, filerAddress string, sessionId string) error
// CleanupExpiredSessions removes expired sessions
CleanupExpiredSessions(ctx context.Context) error
// CleanupExpiredSessions removes expired sessions (filerAddress ignored for memory stores)
CleanupExpiredSessions(ctx context.Context, filerAddress string) error
} }
// NewSTSService creates a new STS service // NewSTSService creates a new STS service
@ -300,7 +300,7 @@ func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error
} }
// AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC) // AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC)
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, filerAddress string, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf(ErrSTSServiceNotInitialized) return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
} }
@ -361,7 +361,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
} }
// 6. Store session information // 6. Store session information
if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil {
if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil {
return nil, fmt.Errorf("failed to store session: %w", err) return nil, fmt.Errorf("failed to store session: %w", err)
} }
@ -379,7 +379,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
} }
// AssumeRoleWithCredentials assumes a role using username/password credentials // AssumeRoleWithCredentials assumes a role using username/password credentials
func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) {
func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, filerAddress string, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) {
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized") return nil, fmt.Errorf("STS service not initialized")
} }
@ -447,7 +447,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
} }
// 7. Store session information // 7. Store session information
if err := s.sessionStore.StoreSession(ctx, sessionId, session); err != nil {
if err := s.sessionStore.StoreSession(ctx, filerAddress, sessionId, session); err != nil {
return nil, fmt.Errorf("failed to store session: %w", err) return nil, fmt.Errorf("failed to store session: %w", err)
} }
@ -465,7 +465,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
} }
// ValidateSessionToken validates a session token and returns session information // ValidateSessionToken validates a session token and returns session information
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) {
func (s *STSService) ValidateSessionToken(ctx context.Context, filerAddress string, sessionToken string) (*SessionInfo, error) {
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf(ErrSTSServiceNotInitialized) return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
} }
@ -481,7 +481,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
} }
// Retrieve session from store using session ID from claims // Retrieve session from store using session ID from claims
session, err := s.sessionStore.GetSession(ctx, claims.SessionId)
session, err := s.sessionStore.GetSession(ctx, filerAddress, claims.SessionId)
if err != nil { if err != nil {
return nil, fmt.Errorf(ErrSessionValidationFailed, err) return nil, fmt.Errorf(ErrSessionValidationFailed, err)
} }
@ -495,7 +495,7 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
} }
// RevokeSession revokes an active session // RevokeSession revokes an active session
func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error {
func (s *STSService) RevokeSession(ctx context.Context, filerAddress string, sessionToken string) error {
if !s.initialized { if !s.initialized {
return fmt.Errorf("STS service not initialized") return fmt.Errorf("STS service not initialized")
} }
@ -511,7 +511,7 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err
} }
// Remove session from store using session ID from claims // Remove session from store using session ID from claims
err = s.sessionStore.RevokeSession(ctx, claims.SessionId)
err = s.sessionStore.RevokeSession(ctx, filerAddress, claims.SessionId)
if err != nil { if err != nil {
return fmt.Errorf("failed to revoke session: %w", err) return fmt.Errorf("failed to revoke session: %w", err)
} }
@ -679,7 +679,7 @@ func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRol
} }
// ExpireSessionForTesting manually expires a session for testing purposes // ExpireSessionForTesting manually expires a session for testing purposes
func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken string) error {
func (s *STSService) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionToken string) error {
if !s.initialized { if !s.initialized {
return fmt.Errorf("STS service not initialized") return fmt.Errorf("STS service not initialized")
} }
@ -696,7 +696,7 @@ func (s *STSService) ExpireSessionForTesting(ctx context.Context, sessionToken s
// Check if session store supports manual expiration (for MemorySessionStore) // Check if session store supports manual expiration (for MemorySessionStore)
if memStore, ok := s.sessionStore.(*MemorySessionStore); ok { if memStore, ok := s.sessionStore.(*MemorySessionStore); ok {
return memStore.ExpireSessionForTesting(ctx, sessionId)
return memStore.ExpireSessionForTesting(ctx, filerAddress, sessionId)
} }
// For other session stores, we could implement similar functionality // For other session stores, we could implement similar functionality

20
weed/iam/sts/sts_service_test.go

@ -67,6 +67,7 @@ func TestSTSServiceInitialization(t *testing.T) {
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens // TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
func TestAssumeRoleWithWebIdentity(t *testing.T) { func TestAssumeRoleWithWebIdentity(t *testing.T) {
service := setupTestSTSService(t) service := setupTestSTSService(t)
testFilerAddress := "localhost:8888" // Dummy filer address for testing
tests := []struct { tests := []struct {
name string name string
@ -121,7 +122,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
DurationSeconds: tt.durationSeconds, DurationSeconds: tt.durationSeconds,
} }
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
@ -155,6 +156,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials // TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
func TestAssumeRoleWithLDAP(t *testing.T) { func TestAssumeRoleWithLDAP(t *testing.T) {
service := setupTestSTSService(t) service := setupTestSTSService(t)
testFilerAddress := "localhost:8888" // Dummy filer address for testing
tests := []struct { tests := []struct {
name string name string
@ -194,7 +196,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) {
ProviderName: "test-ldap", ProviderName: "test-ldap",
} }
response, err := service.AssumeRoleWithCredentials(ctx, request)
response, err := service.AssumeRoleWithCredentials(ctx, testFilerAddress, request)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
@ -212,6 +214,7 @@ func TestAssumeRoleWithLDAP(t *testing.T) {
func TestSessionTokenValidation(t *testing.T) { func TestSessionTokenValidation(t *testing.T) {
service := setupTestSTSService(t) service := setupTestSTSService(t)
ctx := context.Background() ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// First, create a session // First, create a session
request := &AssumeRoleWithWebIdentityRequest{ request := &AssumeRoleWithWebIdentityRequest{
@ -220,7 +223,7 @@ func TestSessionTokenValidation(t *testing.T) {
RoleSessionName: "test-session", RoleSessionName: "test-session",
} }
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, response) require.NotNil(t, response)
@ -250,7 +253,7 @@ func TestSessionTokenValidation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
session, err := service.ValidateSessionToken(ctx, tt.token)
session, err := service.ValidateSessionToken(ctx, testFilerAddress, tt.token)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
@ -269,6 +272,7 @@ func TestSessionTokenValidation(t *testing.T) {
func TestSessionRevocation(t *testing.T) { func TestSessionRevocation(t *testing.T) {
service := setupTestSTSService(t) service := setupTestSTSService(t)
ctx := context.Background() ctx := context.Background()
testFilerAddress := "localhost:8888" // Dummy filer address for testing
// Create a session first // Create a session first
request := &AssumeRoleWithWebIdentityRequest{ request := &AssumeRoleWithWebIdentityRequest{
@ -277,22 +281,22 @@ func TestSessionRevocation(t *testing.T) {
RoleSessionName: "test-session", RoleSessionName: "test-session",
} }
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request)
require.NoError(t, err) require.NoError(t, err)
sessionToken := response.Credentials.SessionToken sessionToken := response.Credentials.SessionToken
// Verify token is valid before revocation // Verify token is valid before revocation
session, err := service.ValidateSessionToken(ctx, sessionToken)
session, err := service.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, session) assert.NotNil(t, session)
// Revoke the session // Revoke the session
err = service.RevokeSession(ctx, sessionToken)
err = service.RevokeSession(ctx, testFilerAddress, sessionToken)
assert.NoError(t, err) assert.NoError(t, err)
// Verify token is no longer valid after revocation // Verify token is no longer valid after revocation
session, err = service.ValidateSessionToken(ctx, sessionToken)
session, err = service.ValidateSessionToken(ctx, testFilerAddress, sessionToken)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, session) assert.Nil(t, session)
} }

Loading…
Cancel
Save