diff --git a/weed/iam/sts/session_store.go b/weed/iam/sts/session_store.go new file mode 100644 index 000000000..f646caf02 --- /dev/null +++ b/weed/iam/sts/session_store.go @@ -0,0 +1,144 @@ +package sts + +import ( + "context" + "fmt" + "sync" + "time" +) + +// MemorySessionStore implements SessionStore using in-memory storage +type MemorySessionStore struct { + sessions map[string]*SessionInfo + mutex sync.RWMutex +} + +// NewMemorySessionStore creates a new memory-based session store +func NewMemorySessionStore() *MemorySessionStore { + return &MemorySessionStore{ + sessions: make(map[string]*SessionInfo), + } +} + +// StoreSession stores session information in memory +func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error { + if sessionId == "" { + return fmt.Errorf("session ID cannot be empty") + } + + if session == nil { + return fmt.Errorf("session cannot be nil") + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + m.sessions[sessionId] = session + return nil +} + +// GetSession retrieves session information from memory +func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) { + if sessionId == "" { + return nil, fmt.Errorf("session ID cannot be empty") + } + + m.mutex.RLock() + defer m.mutex.RUnlock() + + session, exists := m.sessions[sessionId] + if !exists { + return nil, fmt.Errorf("session not found") + } + + // Check if session has expired + if time.Now().After(session.ExpiresAt) { + return nil, fmt.Errorf("session has expired") + } + + return session, nil +} + +// RevokeSession revokes a session from memory +func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string) error { + if sessionId == "" { + return fmt.Errorf("session ID cannot be empty") + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.sessions, sessionId) + return nil +} + +// CleanupExpiredSessions removes expired sessions from memory +func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + now := time.Now() + for sessionId, session := range m.sessions { + if now.After(session.ExpiresAt) { + delete(m.sessions, sessionId) + } + } + + return nil +} + +// FilerSessionStore implements SessionStore using SeaweedFS filer +type FilerSessionStore struct { + // TODO: Add filer client configuration + basePath string +} + +// NewFilerSessionStore creates a new filer-based session store +func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { + // TODO: Implement filer session store initialization + // 1. Parse configuration for filer connection + // 2. Set up filer client + // 3. Configure base path for session storage + + return nil, fmt.Errorf("filer session store not implemented yet") +} + +// StoreSession stores session information in filer +func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error { + // TODO: Implement filer session storage + // 1. Serialize session information to JSON/protobuf + // 2. Store in filer at configured path + sessionId + // 3. Handle errors and retries + + return fmt.Errorf("filer session storage not implemented yet") +} + +// GetSession retrieves session information from filer +func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) { + // TODO: Implement filer session retrieval + // 1. Read session data from filer + // 2. Deserialize JSON/protobuf to SessionInfo + // 3. Check expiration + // 4. Handle not found cases + + return nil, fmt.Errorf("filer session retrieval not implemented yet") +} + +// RevokeSession revokes a session from filer +func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) error { + // TODO: Implement filer session revocation + // 1. Delete session file from filer + // 2. Handle errors + + return fmt.Errorf("filer session revocation not implemented yet") +} + +// CleanupExpiredSessions removes expired sessions from filer +func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error { + // TODO: Implement filer session cleanup + // 1. List all session files in base path + // 2. Read and check expiration times + // 3. Delete expired sessions + + return fmt.Errorf("filer session cleanup not implemented yet") +} diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go new file mode 100644 index 000000000..25aae268d --- /dev/null +++ b/weed/iam/sts/sts_service.go @@ -0,0 +1,317 @@ +package sts + +import ( + "context" + "fmt" + "time" + + "github.com/seaweedfs/seaweedfs/weed/iam/providers" +) + +// STSService provides Security Token Service functionality +type STSService struct { + config *STSConfig + initialized bool + providers map[string]providers.IdentityProvider + sessionStore SessionStore +} + +// STSConfig holds STS service configuration +type STSConfig struct { + // TokenDuration is the default duration for issued tokens + TokenDuration time.Duration `json:"tokenDuration"` + + // MaxSessionLength is the maximum duration for any session + MaxSessionLength time.Duration `json:"maxSessionLength"` + + // Issuer is the STS issuer identifier + Issuer string `json:"issuer"` + + // SigningKey is used to sign session tokens + SigningKey []byte `json:"signingKey"` + + // SessionStore configuration + SessionStoreType string `json:"sessionStoreType"` // memory, filer, redis + SessionStoreConfig map[string]interface{} `json:"sessionStoreConfig,omitempty"` +} + +// AssumeRoleWithWebIdentityRequest represents a request to assume role with web identity +type AssumeRoleWithWebIdentityRequest struct { + // RoleArn is the ARN of the role to assume + RoleArn string `json:"RoleArn"` + + // WebIdentityToken is the OIDC token from the identity provider + WebIdentityToken string `json:"WebIdentityToken"` + + // RoleSessionName is a name for the assumed role session + RoleSessionName string `json:"RoleSessionName"` + + // DurationSeconds is the duration of the role session (optional) + DurationSeconds *int64 `json:"DurationSeconds,omitempty"` + + // Policy is an optional session policy (optional) + Policy *string `json:"Policy,omitempty"` +} + +// AssumeRoleWithCredentialsRequest represents a request to assume role with username/password +type AssumeRoleWithCredentialsRequest struct { + // RoleArn is the ARN of the role to assume + RoleArn string `json:"RoleArn"` + + // Username is the username for authentication + Username string `json:"Username"` + + // Password is the password for authentication + Password string `json:"Password"` + + // RoleSessionName is a name for the assumed role session + RoleSessionName string `json:"RoleSessionName"` + + // ProviderName is the name of the identity provider to use + ProviderName string `json:"ProviderName"` + + // DurationSeconds is the duration of the role session (optional) + DurationSeconds *int64 `json:"DurationSeconds,omitempty"` +} + +// AssumeRoleResponse represents the response from assume role operations +type AssumeRoleResponse struct { + // Credentials contains the temporary security credentials + Credentials *Credentials `json:"Credentials"` + + // AssumedRoleUser contains information about the assumed role user + AssumedRoleUser *AssumedRoleUser `json:"AssumedRoleUser"` + + // PackedPolicySize is the percentage of max policy size used (AWS compatibility) + PackedPolicySize *int64 `json:"PackedPolicySize,omitempty"` +} + +// Credentials represents temporary security credentials +type Credentials struct { + // AccessKeyId is the access key ID + AccessKeyId string `json:"AccessKeyId"` + + // SecretAccessKey is the secret access key + SecretAccessKey string `json:"SecretAccessKey"` + + // SessionToken is the session token + SessionToken string `json:"SessionToken"` + + // Expiration is when the credentials expire + Expiration time.Time `json:"Expiration"` +} + +// AssumedRoleUser contains information about the assumed role user +type AssumedRoleUser struct { + // AssumedRoleId is the unique identifier of the assumed role + AssumedRoleId string `json:"AssumedRoleId"` + + // Arn is the ARN of the assumed role user + Arn string `json:"Arn"` + + // Subject is the subject identifier from the identity provider + Subject string `json:"Subject,omitempty"` +} + +// SessionInfo represents information about an active session +type SessionInfo struct { + // SessionId is the unique identifier for the session + SessionId string `json:"sessionId"` + + // SessionName is the name of the role session + SessionName string `json:"sessionName"` + + // RoleArn is the ARN of the assumed role + RoleArn string `json:"roleArn"` + + // Subject is the subject identifier from the identity provider + Subject string `json:"subject"` + + // Provider is the identity provider used + Provider string `json:"provider"` + + // CreatedAt is when the session was created + CreatedAt time.Time `json:"createdAt"` + + // ExpiresAt is when the session expires + ExpiresAt time.Time `json:"expiresAt"` + + // Credentials are the temporary credentials for this session + Credentials *Credentials `json:"credentials"` +} + +// SessionStore defines the interface for storing session information +type SessionStore interface { + // StoreSession stores session information + StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error + + // GetSession retrieves session information + GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) + + // RevokeSession revokes a session + RevokeSession(ctx context.Context, sessionId string) error + + // CleanupExpiredSessions removes expired sessions + CleanupExpiredSessions(ctx context.Context) error +} + +// NewSTSService creates a new STS service +func NewSTSService() *STSService { + return &STSService{ + providers: make(map[string]providers.IdentityProvider), + } +} + +// Initialize initializes the STS service with configuration +func (s *STSService) Initialize(config *STSConfig) error { + if config == nil { + return fmt.Errorf("config cannot be nil") + } + + if err := s.validateConfig(config); err != nil { + return fmt.Errorf("invalid STS configuration: %w", err) + } + + s.config = config + + // Initialize session store + sessionStore, err := s.createSessionStore(config) + if err != nil { + return fmt.Errorf("failed to initialize session store: %w", err) + } + s.sessionStore = sessionStore + + s.initialized = true + return nil +} + +// validateConfig validates the STS configuration +func (s *STSService) validateConfig(config *STSConfig) error { + if config.TokenDuration <= 0 { + return fmt.Errorf("token duration must be positive") + } + + if config.MaxSessionLength <= 0 { + return fmt.Errorf("max session length must be positive") + } + + if config.Issuer == "" { + return fmt.Errorf("issuer is required") + } + + if len(config.SigningKey) < 16 { + return fmt.Errorf("signing key must be at least 16 bytes") + } + + return nil +} + +// createSessionStore creates a session store based on configuration +func (s *STSService) createSessionStore(config *STSConfig) (SessionStore, error) { + switch config.SessionStoreType { + case "", "memory": + return NewMemorySessionStore(), nil + case "filer": + return NewFilerSessionStore(config.SessionStoreConfig) + default: + return nil, fmt.Errorf("unsupported session store type: %s", config.SessionStoreType) + } +} + +// IsInitialized returns whether the service is initialized +func (s *STSService) IsInitialized() bool { + return s.initialized +} + +// RegisterProvider registers an identity provider +func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error { + if provider == nil { + return fmt.Errorf("provider cannot be nil") + } + + name := provider.Name() + if name == "" { + return fmt.Errorf("provider name cannot be empty") + } + + s.providers[name] = provider + return nil +} + +// AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC) +func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) { + if !s.initialized { + return nil, fmt.Errorf("STS service not initialized") + } + + if request == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + // TODO: Implement AssumeRoleWithWebIdentity + // 1. Validate the web identity token with appropriate provider + // 2. Check if the role exists and can be assumed + // 3. Generate temporary credentials + // 4. Create and store session information + // 5. Return response with credentials + + return nil, fmt.Errorf("AssumeRoleWithWebIdentity not implemented yet") +} + +// AssumeRoleWithCredentials assumes a role using username/password credentials +func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) { + if !s.initialized { + return nil, fmt.Errorf("STS service not initialized") + } + + if request == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + // TODO: Implement AssumeRoleWithCredentials + // 1. Validate credentials with the specified provider + // 2. Check if the role exists and can be assumed + // 3. Generate temporary credentials + // 4. Create and store session information + // 5. Return response with credentials + + return nil, fmt.Errorf("AssumeRoleWithCredentials not implemented yet") +} + +// ValidateSessionToken validates a session token and returns session information +func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) { + if !s.initialized { + return nil, fmt.Errorf("STS service not initialized") + } + + if sessionToken == "" { + return nil, fmt.Errorf("session token cannot be empty") + } + + // TODO: Implement session token validation + // 1. Parse and verify the session token + // 2. Extract session ID from token + // 3. Retrieve session from store + // 4. Validate expiration and other claims + // 5. Return session information + + return nil, fmt.Errorf("session token validation not implemented yet") +} + +// RevokeSession revokes an active session +func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error { + if !s.initialized { + return fmt.Errorf("STS service not initialized") + } + + if sessionToken == "" { + return fmt.Errorf("session token cannot be empty") + } + + // TODO: Implement session revocation + // 1. Parse session token to extract session ID + // 2. Remove session from store + // 3. Add token to revocation list + + return fmt.Errorf("session revocation not implemented yet") +} diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go new file mode 100644 index 000000000..b92ee1828 --- /dev/null +++ b/weed/iam/sts/sts_service_test.go @@ -0,0 +1,389 @@ +package sts + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/iam/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestSTSServiceInitialization tests STS service initialization +func TestSTSServiceInitialization(t *testing.T) { + tests := []struct { + name string + config *STSConfig + wantErr bool + }{ + { + name: "valid config", + config: &STSConfig{ + TokenDuration: time.Hour, + MaxSessionLength: time.Hour * 12, + Issuer: "seaweedfs-sts", + SigningKey: []byte("test-signing-key"), + }, + wantErr: false, + }, + { + name: "missing signing key", + config: &STSConfig{ + TokenDuration: time.Hour, + Issuer: "seaweedfs-sts", + }, + wantErr: true, + }, + { + name: "invalid token duration", + config: &STSConfig{ + TokenDuration: -time.Hour, + Issuer: "seaweedfs-sts", + SigningKey: []byte("test-key"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := NewSTSService() + + err := service.Initialize(tt.config) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.True(t, service.IsInitialized()) + } + }) + } +} + +// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens +func TestAssumeRoleWithWebIdentity(t *testing.T) { + service := setupTestSTSService(t) + + tests := []struct { + name string + roleArn string + webIdentityToken string + sessionName string + durationSeconds *int64 + wantErr bool + expectedSubject string + }{ + { + name: "successful role assumption", + roleArn: "arn:seaweed:iam::role/TestRole", + webIdentityToken: "valid-oidc-token", + sessionName: "test-session", + durationSeconds: nil, // Use default + wantErr: false, + expectedSubject: "test-user-id", + }, + { + name: "invalid web identity token", + roleArn: "arn:seaweed:iam::role/TestRole", + webIdentityToken: "invalid-token", + sessionName: "test-session", + wantErr: true, + }, + { + name: "non-existent role", + roleArn: "arn:seaweed:iam::role/NonExistentRole", + webIdentityToken: "valid-oidc-token", + sessionName: "test-session", + wantErr: true, + }, + { + name: "custom session duration", + roleArn: "arn:seaweed:iam::role/TestRole", + webIdentityToken: "valid-oidc-token", + sessionName: "test-session", + durationSeconds: int64Ptr(7200), // 2 hours + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + request := &AssumeRoleWithWebIdentityRequest{ + RoleArn: tt.roleArn, + WebIdentityToken: tt.webIdentityToken, + RoleSessionName: tt.sessionName, + DurationSeconds: tt.durationSeconds, + } + + response, err := service.AssumeRoleWithWebIdentity(ctx, request) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + assert.NotNil(t, response.Credentials) + assert.NotNil(t, response.AssumedRoleUser) + + // Verify credentials + creds := response.Credentials + assert.NotEmpty(t, creds.AccessKeyId) + assert.NotEmpty(t, creds.SecretAccessKey) + assert.NotEmpty(t, creds.SessionToken) + assert.True(t, creds.Expiration.After(time.Now())) + + // Verify assumed role user + user := response.AssumedRoleUser + assert.Equal(t, tt.roleArn, user.AssumedRoleId) + assert.Contains(t, user.Arn, tt.sessionName) + + if tt.expectedSubject != "" { + assert.Equal(t, tt.expectedSubject, user.Subject) + } + } + }) + } +} + +// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials +func TestAssumeRoleWithLDAP(t *testing.T) { + service := setupTestSTSService(t) + + tests := []struct { + name string + roleArn string + username string + password string + sessionName string + wantErr bool + }{ + { + name: "successful LDAP role assumption", + roleArn: "arn:seaweed:iam::role/LDAPRole", + username: "testuser", + password: "testpass", + sessionName: "ldap-session", + wantErr: false, + }, + { + name: "invalid LDAP credentials", + roleArn: "arn:seaweed:iam::role/LDAPRole", + username: "testuser", + password: "wrongpass", + sessionName: "ldap-session", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + request := &AssumeRoleWithCredentialsRequest{ + RoleArn: tt.roleArn, + Username: tt.username, + Password: tt.password, + RoleSessionName: tt.sessionName, + ProviderName: "test-ldap", + } + + response, err := service.AssumeRoleWithCredentials(ctx, request) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + assert.NotNil(t, response.Credentials) + } + }) + } +} + +// TestSessionTokenValidation tests session token validation +func TestSessionTokenValidation(t *testing.T) { + service := setupTestSTSService(t) + ctx := context.Background() + + // First, create a session + request := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/TestRole", + WebIdentityToken: "valid-oidc-token", + RoleSessionName: "test-session", + } + + response, err := service.AssumeRoleWithWebIdentity(ctx, request) + require.NoError(t, err) + require.NotNil(t, response) + + sessionToken := response.Credentials.SessionToken + + tests := []struct { + name string + token string + wantErr bool + }{ + { + name: "valid session token", + token: sessionToken, + wantErr: false, + }, + { + name: "invalid session token", + token: "invalid-session-token", + wantErr: true, + }, + { + name: "empty session token", + token: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session, err := service.ValidateSessionToken(ctx, tt.token) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, session) + } else { + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Equal(t, "test-session", session.SessionName) + assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn) + } + }) + } +} + +// TestSessionRevocation tests session token revocation +func TestSessionRevocation(t *testing.T) { + service := setupTestSTSService(t) + ctx := context.Background() + + // Create a session first + request := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/TestRole", + WebIdentityToken: "valid-oidc-token", + RoleSessionName: "test-session", + } + + response, err := service.AssumeRoleWithWebIdentity(ctx, request) + require.NoError(t, err) + + sessionToken := response.Credentials.SessionToken + + // Verify token is valid before revocation + session, err := service.ValidateSessionToken(ctx, sessionToken) + assert.NoError(t, err) + assert.NotNil(t, session) + + // Revoke the session + err = service.RevokeSession(ctx, sessionToken) + assert.NoError(t, err) + + // Verify token is no longer valid after revocation + session, err = service.ValidateSessionToken(ctx, sessionToken) + assert.Error(t, err) + assert.Nil(t, session) +} + +// Helper functions + +func setupTestSTSService(t *testing.T) *STSService { + service := NewSTSService() + + config := &STSConfig{ + TokenDuration: time.Hour, + MaxSessionLength: time.Hour * 12, + Issuer: "test-sts", + SigningKey: []byte("test-signing-key-32-characters-long"), + } + + err := service.Initialize(config) + require.NoError(t, err) + + // Register test providers + mockOIDCProvider := &MockIdentityProvider{ + name: "test-oidc", + validTokens: map[string]*providers.TokenClaims{ + "valid-oidc-token": { + Subject: "test-user-id", + Issuer: "test-issuer", + Claims: map[string]interface{}{ + "email": "test@example.com", + "name": "Test User", + }, + }, + }, + } + + mockLDAPProvider := &MockIdentityProvider{ + name: "test-ldap", + validCredentials: map[string]string{ + "testuser": "testpass", + }, + } + + service.RegisterProvider(mockOIDCProvider) + service.RegisterProvider(mockLDAPProvider) + + return service +} + +func int64Ptr(v int64) *int64 { + return &v +} + +// Mock identity provider for testing +type MockIdentityProvider struct { + name string + validTokens map[string]*providers.TokenClaims + validCredentials map[string]string +} + +func (m *MockIdentityProvider) Name() string { + return m.name +} + +func (m *MockIdentityProvider) Initialize(config interface{}) error { + return nil +} + +func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { + if claims, exists := m.validTokens[token]; exists { + email, _ := claims.GetClaimString("email") + name, _ := claims.GetClaimString("name") + + return &providers.ExternalIdentity{ + UserID: claims.Subject, + Email: email, + DisplayName: name, + Provider: m.name, + }, nil + } + return nil, fmt.Errorf("invalid token") +} + +func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { + return &providers.ExternalIdentity{ + UserID: userID, + Email: userID + "@" + m.name + ".com", + Provider: m.name, + }, nil +} + +func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { + if claims, exists := m.validTokens[token]; exists { + return claims, nil + } + return nil, fmt.Errorf("invalid token") +}