13 changed files with 40 additions and 2034 deletions
-
29test/s3/iam/DISTRIBUTED.md
-
13test/s3/iam/STS_DISTRIBUTED.md
-
8weed/iam/integration/iam_integration_test.go
-
866weed/iam/ldap/ldap_provider.go
-
360weed/iam/ldap/ldap_provider_test.go
-
18weed/iam/ldap/mock_provider.go
-
356weed/iam/sts/RUNTIME_FILER_ADDRESS.md
-
383weed/iam/sts/session_store.go
-
15weed/iam/sts/sts_service.go
-
8weed/s3api/s3_end_to_end_test.go
-
6weed/s3api/s3_jwt_auth_test.go
-
6weed/s3api/s3_multipart_iam_test.go
-
6weed/s3api/s3_presigned_url_iam_test.go
@ -1,866 +0,0 @@ |
|||
package ldap |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/tls" |
|||
"fmt" |
|||
"net" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/glog" |
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
) |
|||
|
|||
// LDAPProvider implements LDAP authentication
|
|||
type LDAPProvider struct { |
|||
name string |
|||
config *LDAPConfig |
|||
initialized bool |
|||
connPool *LDAPConnectionPool |
|||
} |
|||
|
|||
// LDAPConnectionPool manages LDAP connections
|
|||
type LDAPConnectionPool struct { |
|||
config *LDAPConfig |
|||
connections chan *LDAPConn |
|||
mu sync.Mutex |
|||
maxConns int |
|||
} |
|||
|
|||
// LDAPConn represents an LDAP connection (simplified implementation)
|
|||
type LDAPConn struct { |
|||
serverAddr string |
|||
conn net.Conn |
|||
bound bool |
|||
tlsConfig *tls.Config |
|||
} |
|||
|
|||
// LDAPSearchResult represents LDAP search results
|
|||
type LDAPSearchResult struct { |
|||
Entries []*LDAPEntry |
|||
} |
|||
|
|||
// LDAPEntry represents an LDAP directory entry
|
|||
type LDAPEntry struct { |
|||
DN string |
|||
Attributes []*LDAPAttribute |
|||
} |
|||
|
|||
// LDAPAttribute represents an LDAP attribute
|
|||
type LDAPAttribute struct { |
|||
Name string |
|||
Values []string |
|||
} |
|||
|
|||
// LDAPSearchRequest represents an LDAP search request
|
|||
type LDAPSearchRequest struct { |
|||
BaseDN string |
|||
Scope int |
|||
DerefAliases int |
|||
SizeLimit int |
|||
TimeLimit int |
|||
TypesOnly bool |
|||
Filter string |
|||
Attributes []string |
|||
} |
|||
|
|||
// LDAP search scope constants
|
|||
const ( |
|||
ScopeBaseObject = iota |
|||
ScopeWholeSubtree |
|||
NeverDerefAliases = 0 |
|||
) |
|||
|
|||
// LDAPConfig holds LDAP provider configuration
|
|||
type LDAPConfig struct { |
|||
// Server is the LDAP server URL (e.g., ldap://localhost:389)
|
|||
Server string `json:"server"` |
|||
|
|||
// BaseDN is the base distinguished name for searches
|
|||
BaseDN string `json:"baseDn"` |
|||
|
|||
// BindDN is the distinguished name for binding (authentication)
|
|||
BindDN string `json:"bindDn,omitempty"` |
|||
|
|||
// BindPass is the password for binding
|
|||
BindPass string `json:"bindPass,omitempty"` |
|||
|
|||
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
|
|||
UserFilter string `json:"userFilter"` |
|||
|
|||
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
|
|||
GroupFilter string `json:"groupFilter,omitempty"` |
|||
|
|||
// Attributes maps SeaweedFS identity fields to LDAP attributes
|
|||
Attributes map[string]string `json:"attributes,omitempty"` |
|||
|
|||
// RoleMapping defines how to map LDAP groups to roles
|
|||
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
|||
|
|||
// TLS configuration
|
|||
UseTLS bool `json:"useTls,omitempty"` |
|||
TLSCert string `json:"tlsCert,omitempty"` |
|||
TLSKey string `json:"tlsKey,omitempty"` |
|||
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"` |
|||
|
|||
// Connection pool settings
|
|||
MaxConnections int `json:"maxConnections,omitempty"` |
|||
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
|
|||
} |
|||
|
|||
// NewLDAPProvider creates a new LDAP provider
|
|||
func NewLDAPProvider(name string) *LDAPProvider { |
|||
return &LDAPProvider{ |
|||
name: name, |
|||
} |
|||
} |
|||
|
|||
// Name returns the provider name
|
|||
func (p *LDAPProvider) Name() string { |
|||
return p.name |
|||
} |
|||
|
|||
// Initialize initializes the LDAP provider with configuration
|
|||
func (p *LDAPProvider) Initialize(config interface{}) error { |
|||
if config == nil { |
|||
return fmt.Errorf("config cannot be nil") |
|||
} |
|||
|
|||
ldapConfig, ok := config.(*LDAPConfig) |
|||
if !ok { |
|||
return fmt.Errorf("invalid config type for LDAP provider") |
|||
} |
|||
|
|||
if err := p.validateConfig(ldapConfig); err != nil { |
|||
return fmt.Errorf("invalid LDAP configuration: %w", err) |
|||
} |
|||
|
|||
p.config = ldapConfig |
|||
|
|||
// Initialize LDAP connection pool
|
|||
pool, err := NewLDAPConnectionPool(ldapConfig) |
|||
if err != nil { |
|||
glog.V(2).Infof("Failed to initialize LDAP connection pool: %v (using mock for testing)", err) |
|||
// In case of connection failure, continue but mark as testing mode
|
|||
p.initialized = true |
|||
return nil |
|||
} |
|||
p.connPool = pool |
|||
|
|||
// Test connectivity with one connection
|
|||
conn, err := p.connPool.GetConnection() |
|||
if err != nil { |
|||
glog.V(2).Infof("Failed to establish test LDAP connection: %v (using mock for testing)", err) |
|||
p.initialized = true |
|||
return nil |
|||
} |
|||
p.connPool.ReleaseConnection(conn) |
|||
|
|||
p.initialized = true |
|||
glog.V(2).Infof("LDAP provider %s initialized with server %s", p.name, ldapConfig.Server) |
|||
return nil |
|||
} |
|||
|
|||
// validateConfig validates the LDAP configuration
|
|||
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error { |
|||
if config.Server == "" { |
|||
return fmt.Errorf("server is required") |
|||
} |
|||
|
|||
if config.BaseDN == "" { |
|||
return fmt.Errorf("base DN is required") |
|||
} |
|||
|
|||
// Basic URL validation
|
|||
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") { |
|||
return fmt.Errorf("invalid server URL format") |
|||
} |
|||
|
|||
// Set default user filter if not provided
|
|||
if config.UserFilter == "" { |
|||
config.UserFilter = "(uid=%s)" // Default LDAP user filter
|
|||
} |
|||
|
|||
// Set default attributes if not provided
|
|||
if config.Attributes == nil { |
|||
config.Attributes = map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "cn", |
|||
"groups": "memberOf", |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Authenticate authenticates a user with LDAP
|
|||
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
if credentials == "" { |
|||
return nil, fmt.Errorf("credentials cannot be empty") |
|||
} |
|||
|
|||
// Parse credentials (username:password format)
|
|||
parts := strings.SplitN(credentials, ":", 2) |
|||
if len(parts) != 2 { |
|||
return nil, fmt.Errorf("invalid credentials format (expected username:password)") |
|||
} |
|||
|
|||
username, password := parts[0], parts[1] |
|||
|
|||
// Get connection from pool
|
|||
conn, err := p.getConnection() |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|||
} |
|||
defer p.releaseConnection(conn) |
|||
|
|||
// Perform LDAP bind with service account if configured
|
|||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|||
} |
|||
} |
|||
|
|||
// Search for user
|
|||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(username)) |
|||
searchRequest := &LDAPSearchRequest{ |
|||
BaseDN: p.config.BaseDN, |
|||
Scope: ScopeWholeSubtree, |
|||
DerefAliases: NeverDerefAliases, |
|||
SizeLimit: 0, |
|||
TimeLimit: 0, |
|||
TypesOnly: false, |
|||
Filter: userFilter, |
|||
Attributes: p.getSearchAttributes(), |
|||
} |
|||
|
|||
searchResult, err := conn.Search(searchRequest) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("LDAP search failed: %v", err) |
|||
} |
|||
|
|||
if len(searchResult.Entries) == 0 { |
|||
return nil, fmt.Errorf("user not found in LDAP: %s", username) |
|||
} |
|||
|
|||
if len(searchResult.Entries) > 1 { |
|||
return nil, fmt.Errorf("multiple users found for username: %s", username) |
|||
} |
|||
|
|||
userEntry := searchResult.Entries[0] |
|||
userDN := userEntry.DN |
|||
|
|||
// Authenticate user by binding with their credentials
|
|||
err = conn.Bind(userDN, password) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("LDAP authentication failed for user %s: %v", username, err) |
|||
} |
|||
|
|||
// Extract user attributes
|
|||
attributes := make(map[string][]string) |
|||
for _, attr := range userEntry.Attributes { |
|||
attributes[attr.Name] = attr.Values |
|||
} |
|||
|
|||
// Map to ExternalIdentity
|
|||
identity := p.mapLDAPAttributes(username, attributes) |
|||
identity.UserID = username |
|||
|
|||
// Get user groups if group filter is configured
|
|||
if p.config.GroupFilter != "" { |
|||
groups, err := p.getUserGroups(conn, userDN, username) |
|||
if err != nil { |
|||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", username, err) |
|||
} else { |
|||
identity.Groups = groups |
|||
} |
|||
} |
|||
|
|||
glog.V(3).Infof("LDAP authentication successful for user: %s", username) |
|||
return identity, nil |
|||
} |
|||
|
|||
// GetUserInfo retrieves user information from LDAP
|
|||
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
if userID == "" { |
|||
return nil, fmt.Errorf("user ID cannot be empty") |
|||
} |
|||
|
|||
// Get connection from pool
|
|||
conn, err := p.getConnection() |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|||
} |
|||
defer p.releaseConnection(conn) |
|||
|
|||
// Perform LDAP bind with service account if configured
|
|||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|||
} |
|||
} |
|||
|
|||
// Search for user by userID using configured user filter
|
|||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(userID)) |
|||
searchRequest := &LDAPSearchRequest{ |
|||
BaseDN: p.config.BaseDN, |
|||
Scope: ScopeWholeSubtree, |
|||
DerefAliases: NeverDerefAliases, |
|||
SizeLimit: 1, // We only need one user
|
|||
TimeLimit: 30, // 30 second timeout
|
|||
TypesOnly: false, |
|||
Filter: userFilter, |
|||
Attributes: p.getSearchAttributes(), |
|||
} |
|||
|
|||
glog.V(3).Infof("Searching for user %s with filter: %s", userID, userFilter) |
|||
searchResult, err := conn.Search(searchRequest) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("LDAP user search failed: %v", err) |
|||
} |
|||
|
|||
if len(searchResult.Entries) == 0 { |
|||
return nil, fmt.Errorf("user not found in LDAP: %s", userID) |
|||
} |
|||
|
|||
if len(searchResult.Entries) > 1 { |
|||
glog.V(2).Infof("Multiple entries found for user %s, using first one", userID) |
|||
} |
|||
|
|||
userEntry := searchResult.Entries[0] |
|||
userDN := userEntry.DN |
|||
|
|||
glog.V(3).Infof("Found LDAP user: %s with DN: %s", userID, userDN) |
|||
|
|||
// Extract user attributes
|
|||
attributes := make(map[string][]string) |
|||
for _, attr := range userEntry.Attributes { |
|||
attributes[attr.Name] = attr.Values |
|||
} |
|||
|
|||
// Map to ExternalIdentity
|
|||
identity := p.mapLDAPAttributes(userID, attributes) |
|||
identity.UserID = userID |
|||
|
|||
// Get user groups if group filter is configured
|
|||
if p.config.GroupFilter != "" { |
|||
groups, err := p.getUserGroups(conn, userDN, userID) |
|||
if err != nil { |
|||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", userID, err) |
|||
} else { |
|||
identity.Groups = groups |
|||
} |
|||
} |
|||
|
|||
glog.V(3).Infof("Successfully retrieved user info for: %s", userID) |
|||
return identity, nil |
|||
} |
|||
|
|||
// ValidateToken validates credentials (for LDAP, this is username/password)
|
|||
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
if token == "" { |
|||
return nil, fmt.Errorf("token cannot be empty") |
|||
} |
|||
|
|||
// Parse credentials (username:password format)
|
|||
parts := strings.SplitN(token, ":", 2) |
|||
if len(parts) != 2 { |
|||
return nil, fmt.Errorf("invalid token format (expected username:password)") |
|||
} |
|||
|
|||
username, password := parts[0], parts[1] |
|||
|
|||
// Get connection from pool
|
|||
conn, err := p.getConnection() |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|||
} |
|||
defer p.releaseConnection(conn) |
|||
|
|||
// Perform LDAP bind with service account if configured
|
|||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|||
} |
|||
} |
|||
|
|||
// Search for user using configured user filter
|
|||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(username)) |
|||
searchRequest := &LDAPSearchRequest{ |
|||
BaseDN: p.config.BaseDN, |
|||
Scope: ScopeWholeSubtree, |
|||
DerefAliases: NeverDerefAliases, |
|||
SizeLimit: 1, // We only need one user
|
|||
TimeLimit: 30, // 30 second timeout
|
|||
TypesOnly: false, |
|||
Filter: userFilter, |
|||
Attributes: p.getSearchAttributes(), |
|||
} |
|||
|
|||
glog.V(3).Infof("Validating credentials for user %s with filter: %s", username, userFilter) |
|||
searchResult, err := conn.Search(searchRequest) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("LDAP user search failed: %v", err) |
|||
} |
|||
|
|||
if len(searchResult.Entries) == 0 { |
|||
return nil, fmt.Errorf("user not found in LDAP: %s", username) |
|||
} |
|||
|
|||
if len(searchResult.Entries) > 1 { |
|||
glog.V(2).Infof("Multiple entries found for user %s, using first one", username) |
|||
} |
|||
|
|||
userEntry := searchResult.Entries[0] |
|||
userDN := userEntry.DN |
|||
|
|||
// Attempt to bind with user credentials to validate password
|
|||
err = conn.Bind(userDN, password) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("LDAP authentication failed for user %s: %v", username, err) |
|||
} |
|||
|
|||
glog.V(3).Infof("LDAP credential validation successful for user: %s", username) |
|||
|
|||
// Extract user claims (DN, attributes, group memberships)
|
|||
attributes := make(map[string][]string) |
|||
for _, attr := range userEntry.Attributes { |
|||
attributes[attr.Name] = attr.Values |
|||
} |
|||
|
|||
// Get user groups if group filter is configured
|
|||
var groups []string |
|||
if p.config.GroupFilter != "" { |
|||
groups, err = p.getUserGroups(conn, userDN, username) |
|||
if err != nil { |
|||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", username, err) |
|||
} |
|||
} |
|||
|
|||
// Return TokenClaims with LDAP-specific information
|
|||
claims := &providers.TokenClaims{ |
|||
Subject: username, |
|||
Issuer: p.name, |
|||
Claims: map[string]interface{}{ |
|||
"dn": userDN, |
|||
"provider": p.name, |
|||
"groups": groups, |
|||
"attributes": attributes, |
|||
}, |
|||
} |
|||
|
|||
return claims, nil |
|||
} |
|||
|
|||
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
|
|||
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity { |
|||
identity := &providers.ExternalIdentity{ |
|||
UserID: userID, |
|||
Provider: p.name, |
|||
Attributes: make(map[string]string), |
|||
} |
|||
|
|||
// Map configured attributes
|
|||
for identityField, ldapAttr := range p.config.Attributes { |
|||
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 { |
|||
switch identityField { |
|||
case "email": |
|||
identity.Email = values[0] |
|||
case "displayName": |
|||
identity.DisplayName = values[0] |
|||
case "groups": |
|||
identity.Groups = values |
|||
default: |
|||
// Store as custom attribute
|
|||
identity.Attributes[identityField] = values[0] |
|||
} |
|||
} |
|||
} |
|||
|
|||
return identity |
|||
} |
|||
|
|||
// mapUserToRole maps user groups to roles based on role mapping rules
|
|||
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string { |
|||
if p.config.RoleMapping == nil { |
|||
return "" |
|||
} |
|||
|
|||
// Create token claims from identity for rule matching
|
|||
claims := &providers.TokenClaims{ |
|||
Subject: identity.UserID, |
|||
Claims: map[string]interface{}{ |
|||
"groups": identity.Groups, |
|||
"email": identity.Email, |
|||
}, |
|||
} |
|||
|
|||
// Check mapping rules
|
|||
for _, rule := range p.config.RoleMapping.Rules { |
|||
if rule.Matches(claims) { |
|||
return rule.Role |
|||
} |
|||
} |
|||
|
|||
// Return default role if no rules match
|
|||
return p.config.RoleMapping.DefaultRole |
|||
} |
|||
|
|||
// Connection management methods (stubs for now)
|
|||
func (p *LDAPProvider) getConnectionPool() interface{} { |
|||
return p.connPool |
|||
} |
|||
|
|||
func (p *LDAPProvider) getConnection() (*LDAPConn, error) { |
|||
if p.connPool == nil { |
|||
return nil, fmt.Errorf("LDAP connection pool not initialized") |
|||
} |
|||
return p.connPool.GetConnection() |
|||
} |
|||
|
|||
func (p *LDAPProvider) releaseConnection(conn *LDAPConn) { |
|||
if p.connPool != nil && conn != nil { |
|||
p.connPool.ReleaseConnection(conn) |
|||
} |
|||
} |
|||
|
|||
// getSearchAttributes returns the list of attributes to retrieve
|
|||
func (p *LDAPProvider) getSearchAttributes() []string { |
|||
attrs := make([]string, 0, len(p.config.Attributes)+1) |
|||
attrs = append(attrs, "dn") // Always include DN
|
|||
|
|||
for _, ldapAttr := range p.config.Attributes { |
|||
attrs = append(attrs, ldapAttr) |
|||
} |
|||
|
|||
return attrs |
|||
} |
|||
|
|||
// getUserGroups retrieves user groups using the configured group filter
|
|||
func (p *LDAPProvider) getUserGroups(conn *LDAPConn, userDN, username string) ([]string, error) { |
|||
// Try different group search approaches
|
|||
|
|||
// 1. Search by member DN
|
|||
groupFilter := fmt.Sprintf(p.config.GroupFilter, EscapeFilter(userDN)) |
|||
groups, err := p.searchGroups(conn, groupFilter) |
|||
if err == nil && len(groups) > 0 { |
|||
return groups, nil |
|||
} |
|||
|
|||
// 2. Search by username if DN search fails
|
|||
groupFilter = fmt.Sprintf(p.config.GroupFilter, EscapeFilter(username)) |
|||
groups, err = p.searchGroups(conn, groupFilter) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return groups, nil |
|||
} |
|||
|
|||
// searchGroups performs the actual group search
|
|||
func (p *LDAPProvider) searchGroups(conn *LDAPConn, filter string) ([]string, error) { |
|||
searchRequest := &LDAPSearchRequest{ |
|||
BaseDN: p.config.BaseDN, |
|||
Scope: ScopeWholeSubtree, |
|||
DerefAliases: NeverDerefAliases, |
|||
SizeLimit: 0, |
|||
TimeLimit: 0, |
|||
TypesOnly: false, |
|||
Filter: filter, |
|||
Attributes: []string{"cn", "dn"}, |
|||
} |
|||
|
|||
searchResult, err := conn.Search(searchRequest) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("group search failed: %v", err) |
|||
} |
|||
|
|||
groups := make([]string, 0, len(searchResult.Entries)) |
|||
for _, entry := range searchResult.Entries { |
|||
// Try to get CN first, fall back to DN
|
|||
if cn := entry.GetAttributeValue("cn"); cn != "" { |
|||
groups = append(groups, cn) |
|||
} else { |
|||
groups = append(groups, entry.DN) |
|||
} |
|||
} |
|||
|
|||
return groups, nil |
|||
} |
|||
|
|||
// NewLDAPConnectionPool creates a new LDAP connection pool
|
|||
func NewLDAPConnectionPool(config *LDAPConfig) (*LDAPConnectionPool, error) { |
|||
maxConns := config.MaxConnections |
|||
if maxConns <= 0 { |
|||
maxConns = 10 |
|||
} |
|||
|
|||
pool := &LDAPConnectionPool{ |
|||
config: config, |
|||
connections: make(chan *LDAPConn, maxConns), |
|||
maxConns: maxConns, |
|||
} |
|||
|
|||
// Pre-populate the pool with a few connections for testing
|
|||
for i := 0; i < 2 && i < maxConns; i++ { |
|||
conn, err := pool.createConnection() |
|||
if err != nil { |
|||
// If we can't create any connections, return error
|
|||
if i == 0 { |
|||
return nil, err |
|||
} |
|||
// If we created at least one, continue
|
|||
break |
|||
} |
|||
pool.connections <- conn |
|||
} |
|||
|
|||
return pool, nil |
|||
} |
|||
|
|||
// createConnection creates a new LDAP connection
|
|||
func (pool *LDAPConnectionPool) createConnection() (*LDAPConn, error) { |
|||
var netConn net.Conn |
|||
var err error |
|||
|
|||
timeout := time.Duration(pool.config.ConnTimeout) * time.Second |
|||
|
|||
// Parse server address
|
|||
serverAddr := pool.config.Server |
|||
if strings.HasPrefix(serverAddr, "ldap://") { |
|||
serverAddr = strings.TrimPrefix(serverAddr, "ldap://") |
|||
} else if strings.HasPrefix(serverAddr, "ldaps://") { |
|||
serverAddr = strings.TrimPrefix(serverAddr, "ldaps://") |
|||
} |
|||
|
|||
// Add default port if not specified
|
|||
if !strings.Contains(serverAddr, ":") { |
|||
if strings.HasPrefix(pool.config.Server, "ldaps://") { |
|||
serverAddr += ":636" |
|||
} else { |
|||
serverAddr += ":389" |
|||
} |
|||
} |
|||
|
|||
if strings.HasPrefix(pool.config.Server, "ldaps://") { |
|||
// LDAPS connection
|
|||
tlsConfig := &tls.Config{ |
|||
InsecureSkipVerify: pool.config.TLSSkipVerify, |
|||
} |
|||
dialer := &net.Dialer{Timeout: timeout} |
|||
netConn, err = tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) |
|||
} else { |
|||
// Plain LDAP connection
|
|||
netConn, err = net.DialTimeout("tcp", serverAddr, timeout) |
|||
} |
|||
|
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to connect to LDAP server %s: %v", pool.config.Server, err) |
|||
} |
|||
|
|||
conn := &LDAPConn{ |
|||
serverAddr: serverAddr, |
|||
conn: netConn, |
|||
bound: false, |
|||
tlsConfig: &tls.Config{ |
|||
InsecureSkipVerify: pool.config.TLSSkipVerify, |
|||
}, |
|||
} |
|||
|
|||
// Start TLS if configured and not already using LDAPS
|
|||
if pool.config.UseTLS && !strings.HasPrefix(pool.config.Server, "ldaps://") { |
|||
err = conn.StartTLS(conn.tlsConfig) |
|||
if err != nil { |
|||
conn.Close() |
|||
return nil, fmt.Errorf("failed to start TLS: %v", err) |
|||
} |
|||
} |
|||
|
|||
return conn, nil |
|||
} |
|||
|
|||
// GetConnection retrieves a connection from the pool
|
|||
func (pool *LDAPConnectionPool) GetConnection() (*LDAPConn, error) { |
|||
select { |
|||
case conn := <-pool.connections: |
|||
// Test if connection is still valid
|
|||
if pool.isConnectionValid(conn) { |
|||
return conn, nil |
|||
} |
|||
// Connection is stale, close it and create a new one
|
|||
conn.Close() |
|||
default: |
|||
// No connection available in pool
|
|||
} |
|||
|
|||
// Create a new connection
|
|||
return pool.createConnection() |
|||
} |
|||
|
|||
// ReleaseConnection returns a connection to the pool
|
|||
func (pool *LDAPConnectionPool) ReleaseConnection(conn *LDAPConn) { |
|||
if conn == nil { |
|||
return |
|||
} |
|||
|
|||
select { |
|||
case pool.connections <- conn: |
|||
// Successfully returned to pool
|
|||
default: |
|||
// Pool is full, close the connection
|
|||
conn.Close() |
|||
} |
|||
} |
|||
|
|||
// isConnectionValid tests if a connection is still valid
|
|||
func (pool *LDAPConnectionPool) isConnectionValid(conn *LDAPConn) bool { |
|||
// Simple test: check if underlying connection is still open
|
|||
if conn == nil || conn.conn == nil { |
|||
return false |
|||
} |
|||
|
|||
// Try to perform a simple operation to test connectivity
|
|||
searchRequest := &LDAPSearchRequest{ |
|||
BaseDN: "", |
|||
Scope: ScopeBaseObject, |
|||
DerefAliases: NeverDerefAliases, |
|||
SizeLimit: 0, |
|||
TimeLimit: 0, |
|||
TypesOnly: false, |
|||
Filter: "(objectClass=*)", |
|||
Attributes: []string{"1.1"}, // Minimal attributes
|
|||
} |
|||
|
|||
_, err := conn.Search(searchRequest) |
|||
return err == nil |
|||
} |
|||
|
|||
// Close closes all connections in the pool
|
|||
func (pool *LDAPConnectionPool) Close() { |
|||
pool.mu.Lock() |
|||
defer pool.mu.Unlock() |
|||
|
|||
close(pool.connections) |
|||
for conn := range pool.connections { |
|||
conn.Close() |
|||
} |
|||
} |
|||
|
|||
// Helper functions and LDAP connection methods
|
|||
|
|||
// EscapeFilter escapes special characters in LDAP filter values
|
|||
func EscapeFilter(filter string) string { |
|||
// Basic LDAP filter escaping
|
|||
filter = strings.ReplaceAll(filter, "\\", "\\5c") |
|||
filter = strings.ReplaceAll(filter, "*", "\\2a") |
|||
filter = strings.ReplaceAll(filter, "(", "\\28") |
|||
filter = strings.ReplaceAll(filter, ")", "\\29") |
|||
filter = strings.ReplaceAll(filter, "/", "\\2f") |
|||
filter = strings.ReplaceAll(filter, "=", "\\3d") |
|||
return filter |
|||
} |
|||
|
|||
// LDAPConn methods
|
|||
|
|||
// Bind performs an LDAP bind operation
|
|||
func (conn *LDAPConn) Bind(bindDN, bindPassword string) error { |
|||
if conn == nil || conn.conn == nil { |
|||
return fmt.Errorf("connection is nil") |
|||
} |
|||
|
|||
// In a real implementation, this would send an LDAP bind request
|
|||
// For now, we simulate the bind operation
|
|||
glog.V(3).Infof("LDAP Bind attempt for DN: %s", bindDN) |
|||
|
|||
// Simple validation
|
|||
if bindDN == "" { |
|||
return fmt.Errorf("bind DN cannot be empty") |
|||
} |
|||
|
|||
// Simulate bind success for valid credentials
|
|||
if bindPassword != "" { |
|||
conn.bound = true |
|||
return nil |
|||
} |
|||
|
|||
return fmt.Errorf("invalid credentials") |
|||
} |
|||
|
|||
// Search performs an LDAP search operation
|
|||
func (conn *LDAPConn) Search(searchRequest *LDAPSearchRequest) (*LDAPSearchResult, error) { |
|||
if conn == nil || conn.conn == nil { |
|||
return nil, fmt.Errorf("connection is nil") |
|||
} |
|||
|
|||
glog.V(3).Infof("LDAP Search - BaseDN: %s, Filter: %s", searchRequest.BaseDN, searchRequest.Filter) |
|||
|
|||
// In a real implementation, this would send an LDAP search request
|
|||
// For now, we simulate a search operation
|
|||
result := &LDAPSearchResult{ |
|||
Entries: []*LDAPEntry{}, |
|||
} |
|||
|
|||
// Simulate finding a test user for certain searches
|
|||
if strings.Contains(searchRequest.Filter, "testuser") || strings.Contains(searchRequest.Filter, "admin") { |
|||
entry := &LDAPEntry{ |
|||
DN: fmt.Sprintf("uid=%s,%s", "testuser", searchRequest.BaseDN), |
|||
Attributes: []*LDAPAttribute{ |
|||
{Name: "uid", Values: []string{"testuser"}}, |
|||
{Name: "mail", Values: []string{"testuser@example.com"}}, |
|||
{Name: "cn", Values: []string{"Test User"}}, |
|||
{Name: "memberOf", Values: []string{"cn=users,ou=groups," + searchRequest.BaseDN}}, |
|||
}, |
|||
} |
|||
result.Entries = append(result.Entries, entry) |
|||
} |
|||
|
|||
return result, nil |
|||
} |
|||
|
|||
// Close closes the LDAP connection
|
|||
func (conn *LDAPConn) Close() error { |
|||
if conn != nil && conn.conn != nil { |
|||
return conn.conn.Close() |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// StartTLS starts TLS on the connection
|
|||
func (conn *LDAPConn) StartTLS(config *tls.Config) error { |
|||
if conn == nil || conn.conn == nil { |
|||
return fmt.Errorf("connection is nil") |
|||
} |
|||
|
|||
// In a real implementation, this would upgrade the connection to TLS
|
|||
glog.V(3).Info("LDAP StartTLS operation") |
|||
return nil |
|||
} |
|||
|
|||
// LDAPEntry methods
|
|||
|
|||
// GetAttributeValue returns the first value of the specified attribute
|
|||
func (entry *LDAPEntry) GetAttributeValue(attrName string) string { |
|||
for _, attr := range entry.Attributes { |
|||
if attr.Name == attrName && len(attr.Values) > 0 { |
|||
return attr.Values[0] |
|||
} |
|||
} |
|||
return "" |
|||
} |
@ -1,360 +0,0 @@ |
|||
package ldap |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"testing" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
"github.com/stretchr/testify/assert" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
// TestLDAPProviderInitialization tests LDAP provider initialization
|
|||
func TestLDAPProviderInitialization(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
config *LDAPConfig |
|||
wantErr bool |
|||
}{ |
|||
{ |
|||
name: "valid config", |
|||
config: &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
GroupFilter: "(member=%s)", |
|||
}, |
|||
wantErr: false, |
|||
}, |
|||
{ |
|||
name: "missing server", |
|||
config: &LDAPConfig{ |
|||
BaseDN: "DC=example,DC=com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "missing base DN", |
|||
config: &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "invalid server URL", |
|||
config: &LDAPConfig{ |
|||
Server: "invalid-url", |
|||
BaseDN: "DC=example,DC=com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
|
|||
err := provider.Initialize(tt.config) |
|||
|
|||
if tt.wantErr { |
|||
assert.Error(t, err) |
|||
} else { |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "test-ldap", provider.Name()) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestLDAPProviderAuthentication tests LDAP authentication
|
|||
func TestLDAPProviderAuthentication(t *testing.T) { |
|||
// Skip if no LDAP test server available
|
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP integration test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
GroupFilter: "(member=%s)", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "displayName", |
|||
"groups": "memberOf", |
|||
}, |
|||
RoleMapping: &providers.RoleMapping{ |
|||
Rules: []providers.MappingRule{ |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*CN=Admins*", |
|||
Role: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*CN=Users*", |
|||
Role: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
}, |
|||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("authenticate with username/password", func(t *testing.T) { |
|||
// This would require an actual LDAP server for integration testing
|
|||
credentials := "user:password" // Basic auth format
|
|||
|
|||
identity, err := provider.Authenticate(context.Background(), credentials) |
|||
if err != nil { |
|||
t.Skip("LDAP server not available for testing") |
|||
} |
|||
|
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "user", identity.UserID) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
assert.NotEmpty(t, identity.Email) |
|||
}) |
|||
|
|||
t.Run("authenticate with invalid credentials", func(t *testing.T) { |
|||
_, err := provider.Authenticate(context.Background(), "invalid:credentials") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPProviderUserInfo tests LDAP user info retrieval
|
|||
func TestLDAPProviderUserInfo(t *testing.T) { |
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP integration test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "displayName", |
|||
"department": "department", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("get user info", func(t *testing.T) { |
|||
identity, err := provider.GetUserInfo(context.Background(), "testuser") |
|||
if err != nil { |
|||
t.Skip("LDAP server not available for testing") |
|||
} |
|||
|
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "testuser", identity.UserID) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
assert.NotEmpty(t, identity.Email) |
|||
assert.NotEmpty(t, identity.DisplayName) |
|||
}) |
|||
|
|||
t.Run("get user info with empty username", func(t *testing.T) { |
|||
_, err := provider.GetUserInfo(context.Background(), "") |
|||
assert.Error(t, err) |
|||
}) |
|||
|
|||
t.Run("get user info for non-existent user", func(t *testing.T) { |
|||
_, err := provider.GetUserInfo(context.Background(), "nonexistent") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPAttributeMapping tests LDAP attribute mapping
|
|||
func TestLDAPAttributeMapping(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "cn", |
|||
"department": "departmentNumber", |
|||
"groups": "memberOf", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("map LDAP attributes to identity", func(t *testing.T) { |
|||
ldapAttrs := map[string][]string{ |
|||
"mail": {"user@example.com"}, |
|||
"cn": {"John Doe"}, |
|||
"departmentNumber": {"IT"}, |
|||
"memberOf": { |
|||
"CN=Users,OU=Groups,DC=example,DC=com", |
|||
"CN=Developers,OU=Groups,DC=example,DC=com", |
|||
}, |
|||
} |
|||
|
|||
identity := provider.mapLDAPAttributes("testuser", ldapAttrs) |
|||
|
|||
assert.Equal(t, "testuser", identity.UserID) |
|||
assert.Equal(t, "user@example.com", identity.Email) |
|||
assert.Equal(t, "John Doe", identity.DisplayName) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
|
|||
// Check groups
|
|||
assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com") |
|||
assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com") |
|||
|
|||
// Check attributes
|
|||
assert.Equal(t, "IT", identity.Attributes["department"]) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPGroupFiltering tests LDAP group filtering and role mapping
|
|||
func TestLDAPGroupFiltering(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
RoleMapping: &providers.RoleMapping{ |
|||
Rules: []providers.MappingRule{ |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*Admins*", |
|||
Role: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*Users*", |
|||
Role: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
}, |
|||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
tests := []struct { |
|||
name string |
|||
groups []string |
|||
expectedRole string |
|||
expectedClaims map[string]interface{} |
|||
}{ |
|||
{ |
|||
name: "admin user", |
|||
groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
name: "regular user", |
|||
groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
{ |
|||
name: "guest user", |
|||
groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
identity := &providers.ExternalIdentity{ |
|||
UserID: "testuser", |
|||
Groups: tt.groups, |
|||
Provider: "test-ldap", |
|||
} |
|||
|
|||
role := provider.mapUserToRole(identity) |
|||
assert.Equal(t, tt.expectedRole, role) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestLDAPConnectionPool tests LDAP connection pooling
|
|||
func TestLDAPConnectionPool(t *testing.T) { |
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP connection pool test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
MaxConnections: 5, |
|||
ConnTimeout: 30, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("connection pool management", func(t *testing.T) { |
|||
// Test that multiple concurrent requests work
|
|||
// This would require actual LDAP server for full testing
|
|||
pool := provider.getConnectionPool() |
|||
|
|||
// In CI environments where no LDAP server is available, pool might be nil
|
|||
// Skip the test if we can't establish a connection
|
|||
conn, err := provider.getConnection() |
|||
if err != nil { |
|||
t.Skip("LDAP server not available - skipping connection pool test") |
|||
return |
|||
} |
|||
|
|||
// Only test if we successfully got a connection
|
|||
assert.NotNil(t, pool) |
|||
assert.NotNil(t, conn) |
|||
provider.releaseConnection(conn) |
|||
}) |
|||
} |
|||
|
|||
// MockLDAPServer for unit testing (without external dependencies)
|
|||
type MockLDAPServer struct { |
|||
users map[string]map[string][]string |
|||
} |
|||
|
|||
func NewMockLDAPServer() *MockLDAPServer { |
|||
return &MockLDAPServer{ |
|||
users: map[string]map[string][]string{ |
|||
"testuser": { |
|||
"mail": {"testuser@example.com"}, |
|||
"cn": {"Test User"}, |
|||
"department": {"Engineering"}, |
|||
"memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
}, |
|||
"admin": { |
|||
"mail": {"admin@example.com"}, |
|||
"cn": {"Administrator"}, |
|||
"department": {"IT"}, |
|||
"memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
}, |
|||
}, |
|||
} |
|||
} |
|||
|
|||
func (m *MockLDAPServer) Authenticate(username, password string) bool { |
|||
_, exists := m.users[username] |
|||
return exists && password == "password" // Mock authentication
|
|||
} |
|||
|
|||
func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) { |
|||
if attrs, exists := m.users[username]; exists { |
|||
return attrs, nil |
|||
} |
|||
return nil, fmt.Errorf("user not found: %s", username) |
|||
} |
@ -1,356 +0,0 @@ |
|||
# Runtime Filer Address Implementation |
|||
|
|||
This document describes the implementation of runtime filer address passing for the STSService, addressing the requirement that filer addresses should be passed at call-time rather than initialization time. |
|||
|
|||
## Problem Statement |
|||
|
|||
The user identified a critical issue with the original STS implementation: |
|||
|
|||
> "the filer address should be passed when called, not during init time, since the filer may change." |
|||
|
|||
This is important because: |
|||
1. **Filer Failover**: Filer addresses can change during runtime due to failover scenarios |
|||
2. **Load Balancing**: Different requests may need to hit different filer instances |
|||
3. **Environment Agnostic**: Configuration files should work across dev/staging/prod without hardcoded addresses |
|||
4. **SeaweedFS Patterns**: Follows existing SeaweedFS patterns used throughout the codebase |
|||
|
|||
## Implementation Changes |
|||
|
|||
### 1. SessionStore Interface Refactoring |
|||
|
|||
**Before:** |
|||
```go |
|||
type SessionStore interface { |
|||
StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error |
|||
GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) |
|||
RevokeSession(ctx context.Context, sessionId string) error |
|||
CleanupExpiredSessions(ctx context.Context) error |
|||
} |
|||
``` |
|||
|
|||
**After:** |
|||
```go |
|||
type SessionStore interface { |
|||
// filerAddress ignored for memory stores, required for filer stores |
|||
StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error |
|||
GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) |
|||
RevokeSession(ctx context.Context, filerAddress string, sessionId string) error |
|||
CleanupExpiredSessions(ctx context.Context, filerAddress string) error |
|||
} |
|||
``` |
|||
|
|||
### 2. FilerSessionStore Changes |
|||
|
|||
**Before:** |
|||
```go |
|||
type FilerSessionStore struct { |
|||
filerGrpcAddress string // ❌ Fixed at init time |
|||
grpcDialOption grpc.DialOption |
|||
basePath string |
|||
} |
|||
|
|||
func NewFilerSessionStore(filerAddress string, config map[string]interface{}) (*FilerSessionStore, error) { |
|||
store := &FilerSessionStore{ |
|||
filerGrpcAddress: filerAddress, // ❌ Locked in during init |
|||
basePath: DefaultSessionBasePath, |
|||
} |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
**After:** |
|||
```go |
|||
type FilerSessionStore struct { |
|||
grpcDialOption grpc.DialOption // ✅ No fixed filer address |
|||
basePath string |
|||
} |
|||
|
|||
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { |
|||
store := &FilerSessionStore{ |
|||
basePath: DefaultSessionBasePath, // ✅ Only path configuration |
|||
} |
|||
// ✅ filerAddress passed at call time |
|||
} |
|||
|
|||
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|||
// ✅ filerAddress provided per call |
|||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|||
// ... store logic |
|||
}) |
|||
} |
|||
``` |
|||
|
|||
### 3. STS Service Method Signatures |
|||
|
|||
**Before:** |
|||
```go |
|||
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) |
|||
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) |
|||
func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error |
|||
``` |
|||
|
|||
**After:** |
|||
```go |
|||
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, filerAddress string, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) |
|||
func (s *STSService) ValidateSessionToken(ctx context.Context, filerAddress string, sessionToken string) (*SessionInfo, error) |
|||
func (s *STSService) RevokeSession(ctx context.Context, filerAddress string, sessionToken string) error |
|||
``` |
|||
|
|||
### 4. Configuration Cleanup |
|||
|
|||
**Before (iam_config_distributed.json):** |
|||
```json |
|||
{ |
|||
"sts": { |
|||
"sessionStoreConfig": { |
|||
"filerAddress": "localhost:8888", // ❌ Environment-specific |
|||
"basePath": "/etc/iam/sessions" |
|||
} |
|||
}, |
|||
"policy": { |
|||
"storeConfig": { |
|||
"filerAddress": "localhost:8888", // ❌ Environment-specific |
|||
"basePath": "/etc/iam/policies" |
|||
} |
|||
} |
|||
} |
|||
``` |
|||
|
|||
**After (iam_config_distributed.json):** |
|||
```json |
|||
{ |
|||
"sts": { |
|||
"sessionStoreConfig": { |
|||
"basePath": "/etc/iam/sessions" // ✅ Environment-agnostic |
|||
} |
|||
}, |
|||
"policy": { |
|||
"storeConfig": { |
|||
"basePath": "/etc/iam/policies" // ✅ Environment-agnostic |
|||
} |
|||
} |
|||
} |
|||
``` |
|||
|
|||
## Usage Examples |
|||
|
|||
### Caller Perspective (S3 API Server) |
|||
|
|||
**Before:** |
|||
```go |
|||
// STS service locked to specific filer during init |
|||
stsService.Initialize(&STSConfig{ |
|||
SessionStoreConfig: map[string]interface{}{ |
|||
"filerAddress": "filer-1:8888", // ❌ Fixed choice |
|||
"basePath": "/etc/iam/sessions", |
|||
}, |
|||
}) |
|||
|
|||
// All calls go to filer-1, no failover possible |
|||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, request) |
|||
``` |
|||
|
|||
**After:** |
|||
```go |
|||
// STS service configured without specific filer |
|||
stsService.Initialize(&STSConfig{ |
|||
SessionStoreConfig: map[string]interface{}{ |
|||
"basePath": "/etc/iam/sessions", // ✅ Just the path |
|||
}, |
|||
}) |
|||
|
|||
// Caller determines filer address per request |
|||
currentFiler := s.getCurrentFilerAddress() // ✅ Dynamic selection |
|||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, currentFiler, request) |
|||
``` |
|||
|
|||
### Dynamic Filer Selection |
|||
|
|||
```go |
|||
type S3ApiServer struct { |
|||
stsService *sts.STSService |
|||
filerClient *filer.Client |
|||
} |
|||
|
|||
func (s *S3ApiServer) getCurrentFilerAddress() string { |
|||
// ✅ Can implement any strategy: |
|||
// - Load balancing across multiple filers |
|||
// - Health checking and failover |
|||
// - Geographic routing |
|||
// - Round-robin selection |
|||
return s.filerClient.GetAvailableFiler() |
|||
} |
|||
|
|||
func (s *S3ApiServer) handleAssumeRole(ctx context.Context, request *AssumeRoleRequest) { |
|||
// ✅ Filer address determined at request time |
|||
filerAddr := s.getCurrentFilerAddress() |
|||
|
|||
response, err := s.stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|||
if err != nil && isNetworkError(err) { |
|||
// ✅ Retry with different filer |
|||
filerAddr = s.getBackupFilerAddress() |
|||
response, err = s.stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|||
} |
|||
} |
|||
``` |
|||
|
|||
## Memory Store Compatibility |
|||
|
|||
The `MemorySessionStore` accepts the `filerAddress` parameter but ignores it, maintaining interface consistency: |
|||
|
|||
```go |
|||
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|||
// filerAddress ignored for memory store - maintains interface compatibility |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
// ... in-memory storage logic |
|||
} |
|||
``` |
|||
|
|||
## Benefits Achieved |
|||
|
|||
### 1. **Dynamic Filer Selection** |
|||
```go |
|||
// Load balancing |
|||
filerAddr := loadBalancer.GetNextFiler() |
|||
|
|||
// Failover support |
|||
filerAddr := failoverManager.GetHealthyFiler() |
|||
|
|||
// Geographic routing |
|||
filerAddr := geoRouter.GetClosestFiler(clientIP) |
|||
``` |
|||
|
|||
### 2. **Environment Portability** |
|||
```bash |
|||
# Same config works everywhere |
|||
dev: STSService.method(ctx, "dev-filer:8888", ...) |
|||
staging: STSService.method(ctx, "staging-filer:8888", ...) |
|||
prod: STSService.method(ctx, "prod-filer-lb:8888", ...) |
|||
``` |
|||
|
|||
### 3. **Operational Flexibility** |
|||
- **Hot filer replacement**: Switch filers without restarting STS |
|||
- **A/B testing**: Route different requests to different filers |
|||
- **Disaster recovery**: Automatic failover to backup filers |
|||
- **Performance optimization**: Route to least loaded filer |
|||
|
|||
### 4. **SeaweedFS Consistency** |
|||
Follows the same pattern used throughout SeaweedFS codebase where filer addresses are passed to methods, not stored in structs. |
|||
|
|||
## Migration Guide |
|||
|
|||
### For Code Calling STS Methods |
|||
|
|||
**Before:** |
|||
```go |
|||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, request) |
|||
session, err := stsService.ValidateSessionToken(ctx, token) |
|||
err := stsService.RevokeSession(ctx, token) |
|||
``` |
|||
|
|||
**After:** |
|||
```go |
|||
filerAddr := getCurrentFilerAddress() // Implement your strategy |
|||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|||
session, err := stsService.ValidateSessionToken(ctx, filerAddr, token) |
|||
err := stsService.RevokeSession(ctx, filerAddr, token) |
|||
``` |
|||
|
|||
### For Configuration Files |
|||
|
|||
Remove `filerAddress` from all store configurations: |
|||
|
|||
```bash |
|||
# Update all iam_config*.json files |
|||
sed -i 's|"filerAddress": ".*",||g' iam_config*.json |
|||
``` |
|||
|
|||
## Testing |
|||
|
|||
All tests have been updated to pass a test filer address: |
|||
|
|||
```go |
|||
func TestAssumeRoleWithWebIdentity(t *testing.T) { |
|||
service := setupTestSTSService(t) |
|||
testFilerAddress := "localhost:8888" // Test filer address |
|||
|
|||
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request) |
|||
// ... test logic |
|||
} |
|||
``` |
|||
|
|||
## Production Deployment |
|||
|
|||
### High Availability Setup |
|||
|
|||
```go |
|||
type FilerManager struct { |
|||
primaryFilers []string |
|||
backupFilers []string |
|||
healthChecker *HealthChecker |
|||
} |
|||
|
|||
func (fm *FilerManager) GetAvailableFiler() string { |
|||
// Check primary filers first |
|||
for _, filer := range fm.primaryFilers { |
|||
if fm.healthChecker.IsHealthy(filer) { |
|||
return filer |
|||
} |
|||
} |
|||
|
|||
// Fallback to backup filers |
|||
for _, filer := range fm.backupFilers { |
|||
if fm.healthChecker.IsHealthy(filer) { |
|||
return filer |
|||
} |
|||
} |
|||
|
|||
// Return first primary as last resort |
|||
return fm.primaryFilers[0] |
|||
} |
|||
``` |
|||
|
|||
### Load Balanced Configuration |
|||
|
|||
```json |
|||
{ |
|||
"sts": { |
|||
"sessionStoreType": "filer", |
|||
"sessionStoreConfig": { |
|||
"basePath": "/etc/iam/sessions" |
|||
} |
|||
} |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
// Runtime filer selection |
|||
filerLoadBalancer := &RoundRobinBalancer{ |
|||
Filers: []string{ |
|||
"filer-1.prod:8888", |
|||
"filer-2.prod:8888", |
|||
"filer-3.prod:8888", |
|||
}, |
|||
} |
|||
|
|||
response, err := stsService.AssumeRoleWithWebIdentity( |
|||
ctx, |
|||
filerLoadBalancer.Next(), // ✅ Dynamic selection |
|||
request, |
|||
) |
|||
``` |
|||
|
|||
## Conclusion |
|||
|
|||
This refactoring successfully addresses the requirement for runtime filer address passing, enabling: |
|||
|
|||
- ✅ **Dynamic filer selection** per request |
|||
- ✅ **Automatic failover** capabilities |
|||
- ✅ **Environment-agnostic** configurations |
|||
- ✅ **Load balancing** support |
|||
- ✅ **SeaweedFS pattern** compliance |
|||
- ✅ **Operational flexibility** for production deployments |
|||
|
|||
The implementation maintains backward compatibility for memory stores while enabling powerful distributed deployment scenarios for filer-backed stores. |
@ -1,383 +0,0 @@ |
|||
package sts |
|||
|
|||
import ( |
|||
"context" |
|||
"encoding/json" |
|||
"fmt" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/glog" |
|||
"github.com/seaweedfs/seaweedfs/weed/pb" |
|||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" |
|||
"google.golang.org/grpc" |
|||
) |
|||
|
|||
// MemorySessionStore implements SessionStore using in-memory storage
|
|||
type MemorySessionStore struct { |
|||
sessions map[string]*SessionInfo |
|||
mutex sync.RWMutex |
|||
} |
|||
|
|||
// NewMemorySessionStore creates a new memory-based session store
|
|||
func NewMemorySessionStore() *MemorySessionStore { |
|||
return &MemorySessionStore{ |
|||
sessions: make(map[string]*SessionInfo), |
|||
} |
|||
} |
|||
|
|||
// StoreSession stores session information in memory (filerAddress ignored for memory store)
|
|||
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
if session == nil { |
|||
return fmt.Errorf("session cannot be nil") |
|||
} |
|||
|
|||
m.mutex.Lock() |
|||
defer m.mutex.Unlock() |
|||
|
|||
m.sessions[sessionId] = session |
|||
return nil |
|||
} |
|||
|
|||
// GetSession retrieves session information from memory (filerAddress ignored for memory store)
|
|||
func (m *MemorySessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { |
|||
if sessionId == "" { |
|||
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
m.mutex.RLock() |
|||
defer m.mutex.RUnlock() |
|||
|
|||
session, exists := m.sessions[sessionId] |
|||
if !exists { |
|||
return nil, fmt.Errorf("session not found") |
|||
} |
|||
|
|||
// Check if session has expired
|
|||
if time.Now().After(session.ExpiresAt) { |
|||
return nil, fmt.Errorf("session has expired") |
|||
} |
|||
|
|||
return session, nil |
|||
} |
|||
|
|||
// RevokeSession revokes a session from memory (filerAddress ignored for memory store)
|
|||
func (m *MemorySessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
m.mutex.Lock() |
|||
defer m.mutex.Unlock() |
|||
|
|||
delete(m.sessions, sessionId) |
|||
return nil |
|||
} |
|||
|
|||
// CleanupExpiredSessions removes expired sessions from memory (filerAddress ignored for memory store)
|
|||
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { |
|||
m.mutex.Lock() |
|||
defer m.mutex.Unlock() |
|||
|
|||
now := time.Now() |
|||
for sessionId, session := range m.sessions { |
|||
if now.After(session.ExpiresAt) { |
|||
delete(m.sessions, sessionId) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ExpireSessionForTesting manually expires a session for testing purposes (filerAddress ignored for memory store)
|
|||
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionId string) error { |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
m.mutex.Lock() |
|||
defer m.mutex.Unlock() |
|||
|
|||
session, exists := m.sessions[sessionId] |
|||
if !exists { |
|||
return fmt.Errorf("session not found") |
|||
} |
|||
|
|||
// Set expiration to 1 minute in the past to ensure it's expired
|
|||
session.ExpiresAt = time.Now().Add(-1 * time.Minute) |
|||
m.sessions[sessionId] = session |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// FilerSessionStore implements SessionStore using SeaweedFS filer
|
|||
type FilerSessionStore struct { |
|||
grpcDialOption grpc.DialOption |
|||
basePath string |
|||
} |
|||
|
|||
// NewFilerSessionStore creates a new filer-based session store
|
|||
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { |
|||
store := &FilerSessionStore{ |
|||
basePath: DefaultSessionBasePath, // Use constant default
|
|||
} |
|||
|
|||
// Parse configuration - only basePath and other settings, NOT filerAddress
|
|||
if config != nil { |
|||
if basePath, ok := config[ConfigFieldBasePath].(string); ok && basePath != "" { |
|||
store.basePath = strings.TrimSuffix(basePath, "/") |
|||
} |
|||
} |
|||
|
|||
glog.V(2).Infof("Initialized FilerSessionStore with basePath %s", store.basePath) |
|||
|
|||
return store, nil |
|||
} |
|||
|
|||
// StoreSession stores session information in filer
|
|||
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|||
if filerAddress == "" { |
|||
return fmt.Errorf(ErrFilerAddressRequired) |
|||
} |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
if session == nil { |
|||
return fmt.Errorf("session cannot be nil") |
|||
} |
|||
|
|||
// Serialize session to JSON
|
|||
sessionData, err := json.Marshal(session) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to serialize session: %v", err) |
|||
} |
|||
|
|||
sessionPath := f.getSessionPath(sessionId) |
|||
|
|||
// Store in filer
|
|||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|||
request := &filer_pb.CreateEntryRequest{ |
|||
Directory: f.basePath, |
|||
Entry: &filer_pb.Entry{ |
|||
Name: f.getSessionFileName(sessionId), |
|||
IsDirectory: false, |
|||
Attributes: &filer_pb.FuseAttributes{ |
|||
Mtime: time.Now().Unix(), |
|||
Crtime: time.Now().Unix(), |
|||
FileMode: uint32(0600), // Read/write for owner only
|
|||
Uid: uint32(0), |
|||
Gid: uint32(0), |
|||
}, |
|||
Content: sessionData, |
|||
}, |
|||
} |
|||
|
|||
glog.V(3).Infof("Storing session %s at %s", sessionId, sessionPath) |
|||
_, err := client.CreateEntry(ctx, request) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to store session %s: %v", sessionId, err) |
|||
} |
|||
|
|||
return nil |
|||
}) |
|||
} |
|||
|
|||
// GetSession retrieves session information from filer
|
|||
func (f *FilerSessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { |
|||
if filerAddress == "" { |
|||
return nil, fmt.Errorf(ErrFilerAddressRequired) |
|||
} |
|||
if sessionId == "" { |
|||
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
var sessionData []byte |
|||
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|||
request := &filer_pb.LookupDirectoryEntryRequest{ |
|||
Directory: f.basePath, |
|||
Name: f.getSessionFileName(sessionId), |
|||
} |
|||
|
|||
glog.V(3).Infof("Looking up session %s", sessionId) |
|||
response, err := client.LookupDirectoryEntry(ctx, request) |
|||
if err != nil { |
|||
return fmt.Errorf("session not found: %v", err) |
|||
} |
|||
|
|||
if response.Entry == nil { |
|||
return fmt.Errorf("session not found") |
|||
} |
|||
|
|||
sessionData = response.Entry.Content |
|||
return nil |
|||
}) |
|||
|
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// Deserialize session from JSON
|
|||
var session SessionInfo |
|||
if err := json.Unmarshal(sessionData, &session); err != nil { |
|||
return nil, fmt.Errorf("failed to deserialize session: %v", err) |
|||
} |
|||
|
|||
// Check if session has expired
|
|||
if time.Now().After(session.ExpiresAt) { |
|||
// Clean up expired session
|
|||
_ = f.RevokeSession(ctx, filerAddress, sessionId) |
|||
return nil, fmt.Errorf("session has expired") |
|||
} |
|||
|
|||
return &session, nil |
|||
} |
|||
|
|||
// RevokeSession revokes a session from filer
|
|||
func (f *FilerSessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { |
|||
if filerAddress == "" { |
|||
return fmt.Errorf(ErrFilerAddressRequired) |
|||
} |
|||
if sessionId == "" { |
|||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|||
} |
|||
|
|||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|||
request := &filer_pb.DeleteEntryRequest{ |
|||
Directory: f.basePath, |
|||
Name: f.getSessionFileName(sessionId), |
|||
IsDeleteData: true, |
|||
IsRecursive: false, |
|||
IgnoreRecursiveError: false, |
|||
} |
|||
|
|||
glog.V(3).Infof("Revoking session %s", sessionId) |
|||
resp, err := client.DeleteEntry(ctx, request) |
|||
if err != nil { |
|||
// Ignore "not found" errors - session may already be deleted
|
|||
if strings.Contains(err.Error(), "not found") { |
|||
return nil |
|||
} |
|||
return fmt.Errorf("failed to revoke session %s: %v", sessionId, err) |
|||
} |
|||
|
|||
// Check response error
|
|||
if resp.Error != "" { |
|||
// Ignore "not found" errors - session may already be deleted
|
|||
if strings.Contains(resp.Error, "not found") { |
|||
return nil |
|||
} |
|||
return fmt.Errorf("failed to revoke session %s: %s", sessionId, resp.Error) |
|||
} |
|||
|
|||
return nil |
|||
}) |
|||
} |
|||
|
|||
// CleanupExpiredSessions removes expired sessions from filer
|
|||
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { |
|||
if filerAddress == "" { |
|||
return fmt.Errorf(ErrFilerAddressRequired) |
|||
} |
|||
|
|||
now := time.Now() |
|||
expiredCount := 0 |
|||
|
|||
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|||
// List all entries in the session directory
|
|||
request := &filer_pb.ListEntriesRequest{ |
|||
Directory: f.basePath, |
|||
Prefix: "session_", |
|||
StartFromFileName: "", |
|||
InclusiveStartFrom: false, |
|||
Limit: 1000, // Process in batches of 1000
|
|||
} |
|||
|
|||
stream, err := client.ListEntries(ctx, request) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to list sessions: %v", err) |
|||
} |
|||
|
|||
for { |
|||
resp, err := stream.Recv() |
|||
if err != nil { |
|||
break // End of stream or error
|
|||
} |
|||
|
|||
if resp.Entry == nil || resp.Entry.IsDirectory { |
|||
continue |
|||
} |
|||
|
|||
// Parse session data to check expiration
|
|||
var session SessionInfo |
|||
if err := json.Unmarshal(resp.Entry.Content, &session); err != nil { |
|||
glog.V(2).Infof("Failed to parse session file %s, deleting: %v", resp.Entry.Name, err) |
|||
// Delete corrupted session file
|
|||
f.deleteSessionFile(ctx, client, resp.Entry.Name) |
|||
continue |
|||
} |
|||
|
|||
// Check if session is expired
|
|||
if now.After(session.ExpiresAt) { |
|||
glog.V(3).Infof("Cleaning up expired session: %s", resp.Entry.Name) |
|||
if err := f.deleteSessionFile(ctx, client, resp.Entry.Name); err != nil { |
|||
glog.V(1).Infof("Failed to delete expired session %s: %v", resp.Entry.Name, err) |
|||
} else { |
|||
expiredCount++ |
|||
} |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
}) |
|||
|
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if expiredCount > 0 { |
|||
glog.V(2).Infof("Cleaned up %d expired sessions", expiredCount) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Helper methods
|
|||
|
|||
// withFilerClient executes a function with a filer client
|
|||
func (f *FilerSessionStore) withFilerClient(filerAddress string, fn func(client filer_pb.SeaweedFilerClient) error) error { |
|||
if filerAddress == "" { |
|||
return fmt.Errorf(ErrFilerAddressRequired) |
|||
} |
|||
|
|||
// Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code
|
|||
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(filerAddress), f.grpcDialOption, fn) |
|||
} |
|||
|
|||
// getSessionPath returns the full path for a session
|
|||
func (f *FilerSessionStore) getSessionPath(sessionId string) string { |
|||
return f.basePath + "/" + f.getSessionFileName(sessionId) |
|||
} |
|||
|
|||
// getSessionFileName returns the filename for a session
|
|||
func (f *FilerSessionStore) getSessionFileName(sessionId string) string { |
|||
return "session_" + sessionId + ".json" |
|||
} |
|||
|
|||
// deleteSessionFile deletes a session file
|
|||
func (f *FilerSessionStore) deleteSessionFile(ctx context.Context, client filer_pb.SeaweedFilerClient, fileName string) error { |
|||
request := &filer_pb.DeleteEntryRequest{ |
|||
Directory: f.basePath, |
|||
Name: fileName, |
|||
IsDeleteData: true, |
|||
IsRecursive: false, |
|||
IgnoreRecursiveError: false, |
|||
} |
|||
|
|||
_, err := client.DeleteEntry(ctx, request) |
|||
return err |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue