Browse Source

TDD RED Phase: Security Token Service (STS) foundation

Phase 2 of Advanced IAM Development Plan - STS Implementation

 WHAT WAS CREATED:
- Complete STS service interface with comprehensive test coverage
- AssumeRoleWithWebIdentity (OIDC) and AssumeRoleWithCredentials (LDAP) APIs
- Session token validation and revocation functionality
- Multiple session store implementations (Memory + Filer)
- Professional AWS STS-compatible API structures

 TDD RED PHASE RESULTS:
- All tests compile successfully - interfaces are correct
- Basic initialization tests PASS as expected
- Feature tests FAIL with honest 'not implemented yet' errors
- Production code doesn't lie about its capabilities

📋 COMPREHENSIVE TEST COVERAGE:
- STS service initialization and configuration validation
- Role assumption with OIDC tokens (various scenarios)
- Role assumption with LDAP credentials
- Session token validation and expiration
- Session revocation and cleanup
- Mock providers for isolated testing

🎯 NEXT STEPS (GREEN Phase):
- Implement real JWT token generation and validation
- Build role assumption logic with provider integration
- Create session management and storage
- Add security validations and error handling

This establishes the complete STS foundation with failing tests
that will guide implementation in the GREEN phase.
pull/7160/head
chrislu 1 month ago
parent
commit
c35b75e7c0
  1. 144
      weed/iam/sts/session_store.go
  2. 317
      weed/iam/sts/sts_service.go
  3. 389
      weed/iam/sts/sts_service_test.go

144
weed/iam/sts/session_store.go

@ -0,0 +1,144 @@
package sts
import (
"context"
"fmt"
"sync"
"time"
)
// MemorySessionStore implements SessionStore using in-memory storage
type MemorySessionStore struct {
sessions map[string]*SessionInfo
mutex sync.RWMutex
}
// NewMemorySessionStore creates a new memory-based session store
func NewMemorySessionStore() *MemorySessionStore {
return &MemorySessionStore{
sessions: make(map[string]*SessionInfo),
}
}
// StoreSession stores session information in memory
func (m *MemorySessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
}
if session == nil {
return fmt.Errorf("session cannot be nil")
}
m.mutex.Lock()
defer m.mutex.Unlock()
m.sessions[sessionId] = session
return nil
}
// GetSession retrieves session information from memory
func (m *MemorySessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
if sessionId == "" {
return nil, fmt.Errorf("session ID cannot be empty")
}
m.mutex.RLock()
defer m.mutex.RUnlock()
session, exists := m.sessions[sessionId]
if !exists {
return nil, fmt.Errorf("session not found")
}
// Check if session has expired
if time.Now().After(session.ExpiresAt) {
return nil, fmt.Errorf("session has expired")
}
return session, nil
}
// RevokeSession revokes a session from memory
func (m *MemorySessionStore) RevokeSession(ctx context.Context, sessionId string) error {
if sessionId == "" {
return fmt.Errorf("session ID cannot be empty")
}
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.sessions, sessionId)
return nil
}
// CleanupExpiredSessions removes expired sessions from memory
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context) error {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
for sessionId, session := range m.sessions {
if now.After(session.ExpiresAt) {
delete(m.sessions, sessionId)
}
}
return nil
}
// FilerSessionStore implements SessionStore using SeaweedFS filer
type FilerSessionStore struct {
// TODO: Add filer client configuration
basePath string
}
// NewFilerSessionStore creates a new filer-based session store
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) {
// TODO: Implement filer session store initialization
// 1. Parse configuration for filer connection
// 2. Set up filer client
// 3. Configure base path for session storage
return nil, fmt.Errorf("filer session store not implemented yet")
}
// StoreSession stores session information in filer
func (f *FilerSessionStore) StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error {
// TODO: Implement filer session storage
// 1. Serialize session information to JSON/protobuf
// 2. Store in filer at configured path + sessionId
// 3. Handle errors and retries
return fmt.Errorf("filer session storage not implemented yet")
}
// GetSession retrieves session information from filer
func (f *FilerSessionStore) GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) {
// TODO: Implement filer session retrieval
// 1. Read session data from filer
// 2. Deserialize JSON/protobuf to SessionInfo
// 3. Check expiration
// 4. Handle not found cases
return nil, fmt.Errorf("filer session retrieval not implemented yet")
}
// RevokeSession revokes a session from filer
func (f *FilerSessionStore) RevokeSession(ctx context.Context, sessionId string) error {
// TODO: Implement filer session revocation
// 1. Delete session file from filer
// 2. Handle errors
return fmt.Errorf("filer session revocation not implemented yet")
}
// CleanupExpiredSessions removes expired sessions from filer
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context) error {
// TODO: Implement filer session cleanup
// 1. List all session files in base path
// 2. Read and check expiration times
// 3. Delete expired sessions
return fmt.Errorf("filer session cleanup not implemented yet")
}

317
weed/iam/sts/sts_service.go

@ -0,0 +1,317 @@
package sts
import (
"context"
"fmt"
"time"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// STSService provides Security Token Service functionality
type STSService struct {
config *STSConfig
initialized bool
providers map[string]providers.IdentityProvider
sessionStore SessionStore
}
// STSConfig holds STS service configuration
type STSConfig struct {
// TokenDuration is the default duration for issued tokens
TokenDuration time.Duration `json:"tokenDuration"`
// MaxSessionLength is the maximum duration for any session
MaxSessionLength time.Duration `json:"maxSessionLength"`
// Issuer is the STS issuer identifier
Issuer string `json:"issuer"`
// SigningKey is used to sign session tokens
SigningKey []byte `json:"signingKey"`
// SessionStore configuration
SessionStoreType string `json:"sessionStoreType"` // memory, filer, redis
SessionStoreConfig map[string]interface{} `json:"sessionStoreConfig,omitempty"`
}
// AssumeRoleWithWebIdentityRequest represents a request to assume role with web identity
type AssumeRoleWithWebIdentityRequest struct {
// RoleArn is the ARN of the role to assume
RoleArn string `json:"RoleArn"`
// WebIdentityToken is the OIDC token from the identity provider
WebIdentityToken string `json:"WebIdentityToken"`
// RoleSessionName is a name for the assumed role session
RoleSessionName string `json:"RoleSessionName"`
// DurationSeconds is the duration of the role session (optional)
DurationSeconds *int64 `json:"DurationSeconds,omitempty"`
// Policy is an optional session policy (optional)
Policy *string `json:"Policy,omitempty"`
}
// AssumeRoleWithCredentialsRequest represents a request to assume role with username/password
type AssumeRoleWithCredentialsRequest struct {
// RoleArn is the ARN of the role to assume
RoleArn string `json:"RoleArn"`
// Username is the username for authentication
Username string `json:"Username"`
// Password is the password for authentication
Password string `json:"Password"`
// RoleSessionName is a name for the assumed role session
RoleSessionName string `json:"RoleSessionName"`
// ProviderName is the name of the identity provider to use
ProviderName string `json:"ProviderName"`
// DurationSeconds is the duration of the role session (optional)
DurationSeconds *int64 `json:"DurationSeconds,omitempty"`
}
// AssumeRoleResponse represents the response from assume role operations
type AssumeRoleResponse struct {
// Credentials contains the temporary security credentials
Credentials *Credentials `json:"Credentials"`
// AssumedRoleUser contains information about the assumed role user
AssumedRoleUser *AssumedRoleUser `json:"AssumedRoleUser"`
// PackedPolicySize is the percentage of max policy size used (AWS compatibility)
PackedPolicySize *int64 `json:"PackedPolicySize,omitempty"`
}
// Credentials represents temporary security credentials
type Credentials struct {
// AccessKeyId is the access key ID
AccessKeyId string `json:"AccessKeyId"`
// SecretAccessKey is the secret access key
SecretAccessKey string `json:"SecretAccessKey"`
// SessionToken is the session token
SessionToken string `json:"SessionToken"`
// Expiration is when the credentials expire
Expiration time.Time `json:"Expiration"`
}
// AssumedRoleUser contains information about the assumed role user
type AssumedRoleUser struct {
// AssumedRoleId is the unique identifier of the assumed role
AssumedRoleId string `json:"AssumedRoleId"`
// Arn is the ARN of the assumed role user
Arn string `json:"Arn"`
// Subject is the subject identifier from the identity provider
Subject string `json:"Subject,omitempty"`
}
// SessionInfo represents information about an active session
type SessionInfo struct {
// SessionId is the unique identifier for the session
SessionId string `json:"sessionId"`
// SessionName is the name of the role session
SessionName string `json:"sessionName"`
// RoleArn is the ARN of the assumed role
RoleArn string `json:"roleArn"`
// Subject is the subject identifier from the identity provider
Subject string `json:"subject"`
// Provider is the identity provider used
Provider string `json:"provider"`
// CreatedAt is when the session was created
CreatedAt time.Time `json:"createdAt"`
// ExpiresAt is when the session expires
ExpiresAt time.Time `json:"expiresAt"`
// Credentials are the temporary credentials for this session
Credentials *Credentials `json:"credentials"`
}
// SessionStore defines the interface for storing session information
type SessionStore interface {
// StoreSession stores session information
StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error
// GetSession retrieves session information
GetSession(ctx context.Context, sessionId string) (*SessionInfo, error)
// RevokeSession revokes a session
RevokeSession(ctx context.Context, sessionId string) error
// CleanupExpiredSessions removes expired sessions
CleanupExpiredSessions(ctx context.Context) error
}
// NewSTSService creates a new STS service
func NewSTSService() *STSService {
return &STSService{
providers: make(map[string]providers.IdentityProvider),
}
}
// Initialize initializes the STS service with configuration
func (s *STSService) Initialize(config *STSConfig) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
}
if err := s.validateConfig(config); err != nil {
return fmt.Errorf("invalid STS configuration: %w", err)
}
s.config = config
// Initialize session store
sessionStore, err := s.createSessionStore(config)
if err != nil {
return fmt.Errorf("failed to initialize session store: %w", err)
}
s.sessionStore = sessionStore
s.initialized = true
return nil
}
// validateConfig validates the STS configuration
func (s *STSService) validateConfig(config *STSConfig) error {
if config.TokenDuration <= 0 {
return fmt.Errorf("token duration must be positive")
}
if config.MaxSessionLength <= 0 {
return fmt.Errorf("max session length must be positive")
}
if config.Issuer == "" {
return fmt.Errorf("issuer is required")
}
if len(config.SigningKey) < 16 {
return fmt.Errorf("signing key must be at least 16 bytes")
}
return nil
}
// createSessionStore creates a session store based on configuration
func (s *STSService) createSessionStore(config *STSConfig) (SessionStore, error) {
switch config.SessionStoreType {
case "", "memory":
return NewMemorySessionStore(), nil
case "filer":
return NewFilerSessionStore(config.SessionStoreConfig)
default:
return nil, fmt.Errorf("unsupported session store type: %s", config.SessionStoreType)
}
}
// IsInitialized returns whether the service is initialized
func (s *STSService) IsInitialized() bool {
return s.initialized
}
// RegisterProvider registers an identity provider
func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error {
if provider == nil {
return fmt.Errorf("provider cannot be nil")
}
name := provider.Name()
if name == "" {
return fmt.Errorf("provider name cannot be empty")
}
s.providers[name] = provider
return nil
}
// AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC)
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
}
if request == nil {
return nil, fmt.Errorf("request cannot be nil")
}
// TODO: Implement AssumeRoleWithWebIdentity
// 1. Validate the web identity token with appropriate provider
// 2. Check if the role exists and can be assumed
// 3. Generate temporary credentials
// 4. Create and store session information
// 5. Return response with credentials
return nil, fmt.Errorf("AssumeRoleWithWebIdentity not implemented yet")
}
// AssumeRoleWithCredentials assumes a role using username/password credentials
func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *AssumeRoleWithCredentialsRequest) (*AssumeRoleResponse, error) {
if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
}
if request == nil {
return nil, fmt.Errorf("request cannot be nil")
}
// TODO: Implement AssumeRoleWithCredentials
// 1. Validate credentials with the specified provider
// 2. Check if the role exists and can be assumed
// 3. Generate temporary credentials
// 4. Create and store session information
// 5. Return response with credentials
return nil, fmt.Errorf("AssumeRoleWithCredentials not implemented yet")
}
// ValidateSessionToken validates a session token and returns session information
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) {
if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
}
if sessionToken == "" {
return nil, fmt.Errorf("session token cannot be empty")
}
// TODO: Implement session token validation
// 1. Parse and verify the session token
// 2. Extract session ID from token
// 3. Retrieve session from store
// 4. Validate expiration and other claims
// 5. Return session information
return nil, fmt.Errorf("session token validation not implemented yet")
}
// RevokeSession revokes an active session
func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error {
if !s.initialized {
return fmt.Errorf("STS service not initialized")
}
if sessionToken == "" {
return fmt.Errorf("session token cannot be empty")
}
// TODO: Implement session revocation
// 1. Parse session token to extract session ID
// 2. Remove session from store
// 3. Add token to revocation list
return fmt.Errorf("session revocation not implemented yet")
}

389
weed/iam/sts/sts_service_test.go

@ -0,0 +1,389 @@
package sts
import (
"context"
"fmt"
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSTSServiceInitialization tests STS service initialization
func TestSTSServiceInitialization(t *testing.T) {
tests := []struct {
name string
config *STSConfig
wantErr bool
}{
{
name: "valid config",
config: &STSConfig{
TokenDuration: time.Hour,
MaxSessionLength: time.Hour * 12,
Issuer: "seaweedfs-sts",
SigningKey: []byte("test-signing-key"),
},
wantErr: false,
},
{
name: "missing signing key",
config: &STSConfig{
TokenDuration: time.Hour,
Issuer: "seaweedfs-sts",
},
wantErr: true,
},
{
name: "invalid token duration",
config: &STSConfig{
TokenDuration: -time.Hour,
Issuer: "seaweedfs-sts",
SigningKey: []byte("test-key"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewSTSService()
err := service.Initialize(tt.config)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.True(t, service.IsInitialized())
}
})
}
}
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
func TestAssumeRoleWithWebIdentity(t *testing.T) {
service := setupTestSTSService(t)
tests := []struct {
name string
roleArn string
webIdentityToken string
sessionName string
durationSeconds *int64
wantErr bool
expectedSubject string
}{
{
name: "successful role assumption",
roleArn: "arn:seaweed:iam::role/TestRole",
webIdentityToken: "valid-oidc-token",
sessionName: "test-session",
durationSeconds: nil, // Use default
wantErr: false,
expectedSubject: "test-user-id",
},
{
name: "invalid web identity token",
roleArn: "arn:seaweed:iam::role/TestRole",
webIdentityToken: "invalid-token",
sessionName: "test-session",
wantErr: true,
},
{
name: "non-existent role",
roleArn: "arn:seaweed:iam::role/NonExistentRole",
webIdentityToken: "valid-oidc-token",
sessionName: "test-session",
wantErr: true,
},
{
name: "custom session duration",
roleArn: "arn:seaweed:iam::role/TestRole",
webIdentityToken: "valid-oidc-token",
sessionName: "test-session",
durationSeconds: int64Ptr(7200), // 2 hours
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: tt.roleArn,
WebIdentityToken: tt.webIdentityToken,
RoleSessionName: tt.sessionName,
DurationSeconds: tt.durationSeconds,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, response)
} else {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.NotNil(t, response.Credentials)
assert.NotNil(t, response.AssumedRoleUser)
// Verify credentials
creds := response.Credentials
assert.NotEmpty(t, creds.AccessKeyId)
assert.NotEmpty(t, creds.SecretAccessKey)
assert.NotEmpty(t, creds.SessionToken)
assert.True(t, creds.Expiration.After(time.Now()))
// Verify assumed role user
user := response.AssumedRoleUser
assert.Equal(t, tt.roleArn, user.AssumedRoleId)
assert.Contains(t, user.Arn, tt.sessionName)
if tt.expectedSubject != "" {
assert.Equal(t, tt.expectedSubject, user.Subject)
}
}
})
}
}
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
func TestAssumeRoleWithLDAP(t *testing.T) {
service := setupTestSTSService(t)
tests := []struct {
name string
roleArn string
username string
password string
sessionName string
wantErr bool
}{
{
name: "successful LDAP role assumption",
roleArn: "arn:seaweed:iam::role/LDAPRole",
username: "testuser",
password: "testpass",
sessionName: "ldap-session",
wantErr: false,
},
{
name: "invalid LDAP credentials",
roleArn: "arn:seaweed:iam::role/LDAPRole",
username: "testuser",
password: "wrongpass",
sessionName: "ldap-session",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
request := &AssumeRoleWithCredentialsRequest{
RoleArn: tt.roleArn,
Username: tt.username,
Password: tt.password,
RoleSessionName: tt.sessionName,
ProviderName: "test-ldap",
}
response, err := service.AssumeRoleWithCredentials(ctx, request)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, response)
} else {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.NotNil(t, response.Credentials)
}
})
}
}
// TestSessionTokenValidation tests session token validation
func TestSessionTokenValidation(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
// First, create a session
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:seaweed:iam::role/TestRole",
WebIdentityToken: "valid-oidc-token",
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionToken := response.Credentials.SessionToken
tests := []struct {
name string
token string
wantErr bool
}{
{
name: "valid session token",
token: sessionToken,
wantErr: false,
},
{
name: "invalid session token",
token: "invalid-session-token",
wantErr: true,
},
{
name: "empty session token",
token: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session, err := service.ValidateSessionToken(ctx, tt.token)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, session)
} else {
assert.NoError(t, err)
assert.NotNil(t, session)
assert.Equal(t, "test-session", session.SessionName)
assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
}
})
}
}
// TestSessionRevocation tests session token revocation
func TestSessionRevocation(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
// Create a session first
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:seaweed:iam::role/TestRole",
WebIdentityToken: "valid-oidc-token",
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
sessionToken := response.Credentials.SessionToken
// Verify token is valid before revocation
session, err := service.ValidateSessionToken(ctx, sessionToken)
assert.NoError(t, err)
assert.NotNil(t, session)
// Revoke the session
err = service.RevokeSession(ctx, sessionToken)
assert.NoError(t, err)
// Verify token is no longer valid after revocation
session, err = service.ValidateSessionToken(ctx, sessionToken)
assert.Error(t, err)
assert.Nil(t, session)
}
// Helper functions
func setupTestSTSService(t *testing.T) *STSService {
service := NewSTSService()
config := &STSConfig{
TokenDuration: time.Hour,
MaxSessionLength: time.Hour * 12,
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Register test providers
mockOIDCProvider := &MockIdentityProvider{
name: "test-oidc",
validTokens: map[string]*providers.TokenClaims{
"valid-oidc-token": {
Subject: "test-user-id",
Issuer: "test-issuer",
Claims: map[string]interface{}{
"email": "test@example.com",
"name": "Test User",
},
},
},
}
mockLDAPProvider := &MockIdentityProvider{
name: "test-ldap",
validCredentials: map[string]string{
"testuser": "testpass",
},
}
service.RegisterProvider(mockOIDCProvider)
service.RegisterProvider(mockLDAPProvider)
return service
}
func int64Ptr(v int64) *int64 {
return &v
}
// Mock identity provider for testing
type MockIdentityProvider struct {
name string
validTokens map[string]*providers.TokenClaims
validCredentials map[string]string
}
func (m *MockIdentityProvider) Name() string {
return m.name
}
func (m *MockIdentityProvider) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
if claims, exists := m.validTokens[token]; exists {
email, _ := claims.GetClaimString("email")
name, _ := claims.GetClaimString("name")
return &providers.ExternalIdentity{
UserID: claims.Subject,
Email: email,
DisplayName: name,
Provider: m.name,
}, nil
}
return nil, fmt.Errorf("invalid token")
}
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: userID,
Email: userID + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if claims, exists := m.validTokens[token]; exists {
return claims, nil
}
return nil, fmt.Errorf("invalid token")
}
Loading…
Cancel
Save