13 changed files with 40 additions and 2034 deletions
-
29test/s3/iam/DISTRIBUTED.md
-
13test/s3/iam/STS_DISTRIBUTED.md
-
8weed/iam/integration/iam_integration_test.go
-
866weed/iam/ldap/ldap_provider.go
-
360weed/iam/ldap/ldap_provider_test.go
-
18weed/iam/ldap/mock_provider.go
-
356weed/iam/sts/RUNTIME_FILER_ADDRESS.md
-
383weed/iam/sts/session_store.go
-
15weed/iam/sts/sts_service.go
-
8weed/s3api/s3_end_to_end_test.go
-
6weed/s3api/s3_jwt_auth_test.go
-
6weed/s3api/s3_multipart_iam_test.go
-
6weed/s3api/s3_presigned_url_iam_test.go
@ -1,866 +0,0 @@ |
|||||
package ldap |
|
||||
|
|
||||
import ( |
|
||||
"context" |
|
||||
"crypto/tls" |
|
||||
"fmt" |
|
||||
"net" |
|
||||
"strings" |
|
||||
"sync" |
|
||||
"time" |
|
||||
|
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog" |
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|
||||
) |
|
||||
|
|
||||
// LDAPProvider implements LDAP authentication
|
|
||||
type LDAPProvider struct { |
|
||||
name string |
|
||||
config *LDAPConfig |
|
||||
initialized bool |
|
||||
connPool *LDAPConnectionPool |
|
||||
} |
|
||||
|
|
||||
// LDAPConnectionPool manages LDAP connections
|
|
||||
type LDAPConnectionPool struct { |
|
||||
config *LDAPConfig |
|
||||
connections chan *LDAPConn |
|
||||
mu sync.Mutex |
|
||||
maxConns int |
|
||||
} |
|
||||
|
|
||||
// LDAPConn represents an LDAP connection (simplified implementation)
|
|
||||
type LDAPConn struct { |
|
||||
serverAddr string |
|
||||
conn net.Conn |
|
||||
bound bool |
|
||||
tlsConfig *tls.Config |
|
||||
} |
|
||||
|
|
||||
// LDAPSearchResult represents LDAP search results
|
|
||||
type LDAPSearchResult struct { |
|
||||
Entries []*LDAPEntry |
|
||||
} |
|
||||
|
|
||||
// LDAPEntry represents an LDAP directory entry
|
|
||||
type LDAPEntry struct { |
|
||||
DN string |
|
||||
Attributes []*LDAPAttribute |
|
||||
} |
|
||||
|
|
||||
// LDAPAttribute represents an LDAP attribute
|
|
||||
type LDAPAttribute struct { |
|
||||
Name string |
|
||||
Values []string |
|
||||
} |
|
||||
|
|
||||
// LDAPSearchRequest represents an LDAP search request
|
|
||||
type LDAPSearchRequest struct { |
|
||||
BaseDN string |
|
||||
Scope int |
|
||||
DerefAliases int |
|
||||
SizeLimit int |
|
||||
TimeLimit int |
|
||||
TypesOnly bool |
|
||||
Filter string |
|
||||
Attributes []string |
|
||||
} |
|
||||
|
|
||||
// LDAP search scope constants
|
|
||||
const ( |
|
||||
ScopeBaseObject = iota |
|
||||
ScopeWholeSubtree |
|
||||
NeverDerefAliases = 0 |
|
||||
) |
|
||||
|
|
||||
// LDAPConfig holds LDAP provider configuration
|
|
||||
type LDAPConfig struct { |
|
||||
// Server is the LDAP server URL (e.g., ldap://localhost:389)
|
|
||||
Server string `json:"server"` |
|
||||
|
|
||||
// BaseDN is the base distinguished name for searches
|
|
||||
BaseDN string `json:"baseDn"` |
|
||||
|
|
||||
// BindDN is the distinguished name for binding (authentication)
|
|
||||
BindDN string `json:"bindDn,omitempty"` |
|
||||
|
|
||||
// BindPass is the password for binding
|
|
||||
BindPass string `json:"bindPass,omitempty"` |
|
||||
|
|
||||
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
|
|
||||
UserFilter string `json:"userFilter"` |
|
||||
|
|
||||
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
|
|
||||
GroupFilter string `json:"groupFilter,omitempty"` |
|
||||
|
|
||||
// Attributes maps SeaweedFS identity fields to LDAP attributes
|
|
||||
Attributes map[string]string `json:"attributes,omitempty"` |
|
||||
|
|
||||
// RoleMapping defines how to map LDAP groups to roles
|
|
||||
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
|
||||
|
|
||||
// TLS configuration
|
|
||||
UseTLS bool `json:"useTls,omitempty"` |
|
||||
TLSCert string `json:"tlsCert,omitempty"` |
|
||||
TLSKey string `json:"tlsKey,omitempty"` |
|
||||
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"` |
|
||||
|
|
||||
// Connection pool settings
|
|
||||
MaxConnections int `json:"maxConnections,omitempty"` |
|
||||
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
|
|
||||
} |
|
||||
|
|
||||
// NewLDAPProvider creates a new LDAP provider
|
|
||||
func NewLDAPProvider(name string) *LDAPProvider { |
|
||||
return &LDAPProvider{ |
|
||||
name: name, |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Name returns the provider name
|
|
||||
func (p *LDAPProvider) Name() string { |
|
||||
return p.name |
|
||||
} |
|
||||
|
|
||||
// Initialize initializes the LDAP provider with configuration
|
|
||||
func (p *LDAPProvider) Initialize(config interface{}) error { |
|
||||
if config == nil { |
|
||||
return fmt.Errorf("config cannot be nil") |
|
||||
} |
|
||||
|
|
||||
ldapConfig, ok := config.(*LDAPConfig) |
|
||||
if !ok { |
|
||||
return fmt.Errorf("invalid config type for LDAP provider") |
|
||||
} |
|
||||
|
|
||||
if err := p.validateConfig(ldapConfig); err != nil { |
|
||||
return fmt.Errorf("invalid LDAP configuration: %w", err) |
|
||||
} |
|
||||
|
|
||||
p.config = ldapConfig |
|
||||
|
|
||||
// Initialize LDAP connection pool
|
|
||||
pool, err := NewLDAPConnectionPool(ldapConfig) |
|
||||
if err != nil { |
|
||||
glog.V(2).Infof("Failed to initialize LDAP connection pool: %v (using mock for testing)", err) |
|
||||
// In case of connection failure, continue but mark as testing mode
|
|
||||
p.initialized = true |
|
||||
return nil |
|
||||
} |
|
||||
p.connPool = pool |
|
||||
|
|
||||
// Test connectivity with one connection
|
|
||||
conn, err := p.connPool.GetConnection() |
|
||||
if err != nil { |
|
||||
glog.V(2).Infof("Failed to establish test LDAP connection: %v (using mock for testing)", err) |
|
||||
p.initialized = true |
|
||||
return nil |
|
||||
} |
|
||||
p.connPool.ReleaseConnection(conn) |
|
||||
|
|
||||
p.initialized = true |
|
||||
glog.V(2).Infof("LDAP provider %s initialized with server %s", p.name, ldapConfig.Server) |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// validateConfig validates the LDAP configuration
|
|
||||
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error { |
|
||||
if config.Server == "" { |
|
||||
return fmt.Errorf("server is required") |
|
||||
} |
|
||||
|
|
||||
if config.BaseDN == "" { |
|
||||
return fmt.Errorf("base DN is required") |
|
||||
} |
|
||||
|
|
||||
// Basic URL validation
|
|
||||
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") { |
|
||||
return fmt.Errorf("invalid server URL format") |
|
||||
} |
|
||||
|
|
||||
// Set default user filter if not provided
|
|
||||
if config.UserFilter == "" { |
|
||||
config.UserFilter = "(uid=%s)" // Default LDAP user filter
|
|
||||
} |
|
||||
|
|
||||
// Set default attributes if not provided
|
|
||||
if config.Attributes == nil { |
|
||||
config.Attributes = map[string]string{ |
|
||||
"email": "mail", |
|
||||
"displayName": "cn", |
|
||||
"groups": "memberOf", |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// Authenticate authenticates a user with LDAP
|
|
||||
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) { |
|
||||
if !p.initialized { |
|
||||
return nil, fmt.Errorf("provider not initialized") |
|
||||
} |
|
||||
|
|
||||
if credentials == "" { |
|
||||
return nil, fmt.Errorf("credentials cannot be empty") |
|
||||
} |
|
||||
|
|
||||
// Parse credentials (username:password format)
|
|
||||
parts := strings.SplitN(credentials, ":", 2) |
|
||||
if len(parts) != 2 { |
|
||||
return nil, fmt.Errorf("invalid credentials format (expected username:password)") |
|
||||
} |
|
||||
|
|
||||
username, password := parts[0], parts[1] |
|
||||
|
|
||||
// Get connection from pool
|
|
||||
conn, err := p.getConnection() |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|
||||
} |
|
||||
defer p.releaseConnection(conn) |
|
||||
|
|
||||
// Perform LDAP bind with service account if configured
|
|
||||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|
||||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Search for user
|
|
||||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(username)) |
|
||||
searchRequest := &LDAPSearchRequest{ |
|
||||
BaseDN: p.config.BaseDN, |
|
||||
Scope: ScopeWholeSubtree, |
|
||||
DerefAliases: NeverDerefAliases, |
|
||||
SizeLimit: 0, |
|
||||
TimeLimit: 0, |
|
||||
TypesOnly: false, |
|
||||
Filter: userFilter, |
|
||||
Attributes: p.getSearchAttributes(), |
|
||||
} |
|
||||
|
|
||||
searchResult, err := conn.Search(searchRequest) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("LDAP search failed: %v", err) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) == 0 { |
|
||||
return nil, fmt.Errorf("user not found in LDAP: %s", username) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) > 1 { |
|
||||
return nil, fmt.Errorf("multiple users found for username: %s", username) |
|
||||
} |
|
||||
|
|
||||
userEntry := searchResult.Entries[0] |
|
||||
userDN := userEntry.DN |
|
||||
|
|
||||
// Authenticate user by binding with their credentials
|
|
||||
err = conn.Bind(userDN, password) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("LDAP authentication failed for user %s: %v", username, err) |
|
||||
} |
|
||||
|
|
||||
// Extract user attributes
|
|
||||
attributes := make(map[string][]string) |
|
||||
for _, attr := range userEntry.Attributes { |
|
||||
attributes[attr.Name] = attr.Values |
|
||||
} |
|
||||
|
|
||||
// Map to ExternalIdentity
|
|
||||
identity := p.mapLDAPAttributes(username, attributes) |
|
||||
identity.UserID = username |
|
||||
|
|
||||
// Get user groups if group filter is configured
|
|
||||
if p.config.GroupFilter != "" { |
|
||||
groups, err := p.getUserGroups(conn, userDN, username) |
|
||||
if err != nil { |
|
||||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", username, err) |
|
||||
} else { |
|
||||
identity.Groups = groups |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("LDAP authentication successful for user: %s", username) |
|
||||
return identity, nil |
|
||||
} |
|
||||
|
|
||||
// GetUserInfo retrieves user information from LDAP
|
|
||||
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
|
||||
if !p.initialized { |
|
||||
return nil, fmt.Errorf("provider not initialized") |
|
||||
} |
|
||||
|
|
||||
if userID == "" { |
|
||||
return nil, fmt.Errorf("user ID cannot be empty") |
|
||||
} |
|
||||
|
|
||||
// Get connection from pool
|
|
||||
conn, err := p.getConnection() |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|
||||
} |
|
||||
defer p.releaseConnection(conn) |
|
||||
|
|
||||
// Perform LDAP bind with service account if configured
|
|
||||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|
||||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Search for user by userID using configured user filter
|
|
||||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(userID)) |
|
||||
searchRequest := &LDAPSearchRequest{ |
|
||||
BaseDN: p.config.BaseDN, |
|
||||
Scope: ScopeWholeSubtree, |
|
||||
DerefAliases: NeverDerefAliases, |
|
||||
SizeLimit: 1, // We only need one user
|
|
||||
TimeLimit: 30, // 30 second timeout
|
|
||||
TypesOnly: false, |
|
||||
Filter: userFilter, |
|
||||
Attributes: p.getSearchAttributes(), |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Searching for user %s with filter: %s", userID, userFilter) |
|
||||
searchResult, err := conn.Search(searchRequest) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("LDAP user search failed: %v", err) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) == 0 { |
|
||||
return nil, fmt.Errorf("user not found in LDAP: %s", userID) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) > 1 { |
|
||||
glog.V(2).Infof("Multiple entries found for user %s, using first one", userID) |
|
||||
} |
|
||||
|
|
||||
userEntry := searchResult.Entries[0] |
|
||||
userDN := userEntry.DN |
|
||||
|
|
||||
glog.V(3).Infof("Found LDAP user: %s with DN: %s", userID, userDN) |
|
||||
|
|
||||
// Extract user attributes
|
|
||||
attributes := make(map[string][]string) |
|
||||
for _, attr := range userEntry.Attributes { |
|
||||
attributes[attr.Name] = attr.Values |
|
||||
} |
|
||||
|
|
||||
// Map to ExternalIdentity
|
|
||||
identity := p.mapLDAPAttributes(userID, attributes) |
|
||||
identity.UserID = userID |
|
||||
|
|
||||
// Get user groups if group filter is configured
|
|
||||
if p.config.GroupFilter != "" { |
|
||||
groups, err := p.getUserGroups(conn, userDN, userID) |
|
||||
if err != nil { |
|
||||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", userID, err) |
|
||||
} else { |
|
||||
identity.Groups = groups |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Successfully retrieved user info for: %s", userID) |
|
||||
return identity, nil |
|
||||
} |
|
||||
|
|
||||
// ValidateToken validates credentials (for LDAP, this is username/password)
|
|
||||
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
|
||||
if !p.initialized { |
|
||||
return nil, fmt.Errorf("provider not initialized") |
|
||||
} |
|
||||
|
|
||||
if token == "" { |
|
||||
return nil, fmt.Errorf("token cannot be empty") |
|
||||
} |
|
||||
|
|
||||
// Parse credentials (username:password format)
|
|
||||
parts := strings.SplitN(token, ":", 2) |
|
||||
if len(parts) != 2 { |
|
||||
return nil, fmt.Errorf("invalid token format (expected username:password)") |
|
||||
} |
|
||||
|
|
||||
username, password := parts[0], parts[1] |
|
||||
|
|
||||
// Get connection from pool
|
|
||||
conn, err := p.getConnection() |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to get LDAP connection: %v", err) |
|
||||
} |
|
||||
defer p.releaseConnection(conn) |
|
||||
|
|
||||
// Perform LDAP bind with service account if configured
|
|
||||
if p.config.BindDN != "" && p.config.BindPass != "" { |
|
||||
err = conn.Bind(p.config.BindDN, p.config.BindPass) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to bind with service account: %v", err) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Search for user using configured user filter
|
|
||||
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(username)) |
|
||||
searchRequest := &LDAPSearchRequest{ |
|
||||
BaseDN: p.config.BaseDN, |
|
||||
Scope: ScopeWholeSubtree, |
|
||||
DerefAliases: NeverDerefAliases, |
|
||||
SizeLimit: 1, // We only need one user
|
|
||||
TimeLimit: 30, // 30 second timeout
|
|
||||
TypesOnly: false, |
|
||||
Filter: userFilter, |
|
||||
Attributes: p.getSearchAttributes(), |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Validating credentials for user %s with filter: %s", username, userFilter) |
|
||||
searchResult, err := conn.Search(searchRequest) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("LDAP user search failed: %v", err) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) == 0 { |
|
||||
return nil, fmt.Errorf("user not found in LDAP: %s", username) |
|
||||
} |
|
||||
|
|
||||
if len(searchResult.Entries) > 1 { |
|
||||
glog.V(2).Infof("Multiple entries found for user %s, using first one", username) |
|
||||
} |
|
||||
|
|
||||
userEntry := searchResult.Entries[0] |
|
||||
userDN := userEntry.DN |
|
||||
|
|
||||
// Attempt to bind with user credentials to validate password
|
|
||||
err = conn.Bind(userDN, password) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("LDAP authentication failed for user %s: %v", username, err) |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("LDAP credential validation successful for user: %s", username) |
|
||||
|
|
||||
// Extract user claims (DN, attributes, group memberships)
|
|
||||
attributes := make(map[string][]string) |
|
||||
for _, attr := range userEntry.Attributes { |
|
||||
attributes[attr.Name] = attr.Values |
|
||||
} |
|
||||
|
|
||||
// Get user groups if group filter is configured
|
|
||||
var groups []string |
|
||||
if p.config.GroupFilter != "" { |
|
||||
groups, err = p.getUserGroups(conn, userDN, username) |
|
||||
if err != nil { |
|
||||
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", username, err) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Return TokenClaims with LDAP-specific information
|
|
||||
claims := &providers.TokenClaims{ |
|
||||
Subject: username, |
|
||||
Issuer: p.name, |
|
||||
Claims: map[string]interface{}{ |
|
||||
"dn": userDN, |
|
||||
"provider": p.name, |
|
||||
"groups": groups, |
|
||||
"attributes": attributes, |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
return claims, nil |
|
||||
} |
|
||||
|
|
||||
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
|
|
||||
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity { |
|
||||
identity := &providers.ExternalIdentity{ |
|
||||
UserID: userID, |
|
||||
Provider: p.name, |
|
||||
Attributes: make(map[string]string), |
|
||||
} |
|
||||
|
|
||||
// Map configured attributes
|
|
||||
for identityField, ldapAttr := range p.config.Attributes { |
|
||||
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 { |
|
||||
switch identityField { |
|
||||
case "email": |
|
||||
identity.Email = values[0] |
|
||||
case "displayName": |
|
||||
identity.DisplayName = values[0] |
|
||||
case "groups": |
|
||||
identity.Groups = values |
|
||||
default: |
|
||||
// Store as custom attribute
|
|
||||
identity.Attributes[identityField] = values[0] |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return identity |
|
||||
} |
|
||||
|
|
||||
// mapUserToRole maps user groups to roles based on role mapping rules
|
|
||||
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string { |
|
||||
if p.config.RoleMapping == nil { |
|
||||
return "" |
|
||||
} |
|
||||
|
|
||||
// Create token claims from identity for rule matching
|
|
||||
claims := &providers.TokenClaims{ |
|
||||
Subject: identity.UserID, |
|
||||
Claims: map[string]interface{}{ |
|
||||
"groups": identity.Groups, |
|
||||
"email": identity.Email, |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
// Check mapping rules
|
|
||||
for _, rule := range p.config.RoleMapping.Rules { |
|
||||
if rule.Matches(claims) { |
|
||||
return rule.Role |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Return default role if no rules match
|
|
||||
return p.config.RoleMapping.DefaultRole |
|
||||
} |
|
||||
|
|
||||
// Connection management methods (stubs for now)
|
|
||||
func (p *LDAPProvider) getConnectionPool() interface{} { |
|
||||
return p.connPool |
|
||||
} |
|
||||
|
|
||||
func (p *LDAPProvider) getConnection() (*LDAPConn, error) { |
|
||||
if p.connPool == nil { |
|
||||
return nil, fmt.Errorf("LDAP connection pool not initialized") |
|
||||
} |
|
||||
return p.connPool.GetConnection() |
|
||||
} |
|
||||
|
|
||||
func (p *LDAPProvider) releaseConnection(conn *LDAPConn) { |
|
||||
if p.connPool != nil && conn != nil { |
|
||||
p.connPool.ReleaseConnection(conn) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// getSearchAttributes returns the list of attributes to retrieve
|
|
||||
func (p *LDAPProvider) getSearchAttributes() []string { |
|
||||
attrs := make([]string, 0, len(p.config.Attributes)+1) |
|
||||
attrs = append(attrs, "dn") // Always include DN
|
|
||||
|
|
||||
for _, ldapAttr := range p.config.Attributes { |
|
||||
attrs = append(attrs, ldapAttr) |
|
||||
} |
|
||||
|
|
||||
return attrs |
|
||||
} |
|
||||
|
|
||||
// getUserGroups retrieves user groups using the configured group filter
|
|
||||
func (p *LDAPProvider) getUserGroups(conn *LDAPConn, userDN, username string) ([]string, error) { |
|
||||
// Try different group search approaches
|
|
||||
|
|
||||
// 1. Search by member DN
|
|
||||
groupFilter := fmt.Sprintf(p.config.GroupFilter, EscapeFilter(userDN)) |
|
||||
groups, err := p.searchGroups(conn, groupFilter) |
|
||||
if err == nil && len(groups) > 0 { |
|
||||
return groups, nil |
|
||||
} |
|
||||
|
|
||||
// 2. Search by username if DN search fails
|
|
||||
groupFilter = fmt.Sprintf(p.config.GroupFilter, EscapeFilter(username)) |
|
||||
groups, err = p.searchGroups(conn, groupFilter) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
return groups, nil |
|
||||
} |
|
||||
|
|
||||
// searchGroups performs the actual group search
|
|
||||
func (p *LDAPProvider) searchGroups(conn *LDAPConn, filter string) ([]string, error) { |
|
||||
searchRequest := &LDAPSearchRequest{ |
|
||||
BaseDN: p.config.BaseDN, |
|
||||
Scope: ScopeWholeSubtree, |
|
||||
DerefAliases: NeverDerefAliases, |
|
||||
SizeLimit: 0, |
|
||||
TimeLimit: 0, |
|
||||
TypesOnly: false, |
|
||||
Filter: filter, |
|
||||
Attributes: []string{"cn", "dn"}, |
|
||||
} |
|
||||
|
|
||||
searchResult, err := conn.Search(searchRequest) |
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("group search failed: %v", err) |
|
||||
} |
|
||||
|
|
||||
groups := make([]string, 0, len(searchResult.Entries)) |
|
||||
for _, entry := range searchResult.Entries { |
|
||||
// Try to get CN first, fall back to DN
|
|
||||
if cn := entry.GetAttributeValue("cn"); cn != "" { |
|
||||
groups = append(groups, cn) |
|
||||
} else { |
|
||||
groups = append(groups, entry.DN) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return groups, nil |
|
||||
} |
|
||||
|
|
||||
// NewLDAPConnectionPool creates a new LDAP connection pool
|
|
||||
func NewLDAPConnectionPool(config *LDAPConfig) (*LDAPConnectionPool, error) { |
|
||||
maxConns := config.MaxConnections |
|
||||
if maxConns <= 0 { |
|
||||
maxConns = 10 |
|
||||
} |
|
||||
|
|
||||
pool := &LDAPConnectionPool{ |
|
||||
config: config, |
|
||||
connections: make(chan *LDAPConn, maxConns), |
|
||||
maxConns: maxConns, |
|
||||
} |
|
||||
|
|
||||
// Pre-populate the pool with a few connections for testing
|
|
||||
for i := 0; i < 2 && i < maxConns; i++ { |
|
||||
conn, err := pool.createConnection() |
|
||||
if err != nil { |
|
||||
// If we can't create any connections, return error
|
|
||||
if i == 0 { |
|
||||
return nil, err |
|
||||
} |
|
||||
// If we created at least one, continue
|
|
||||
break |
|
||||
} |
|
||||
pool.connections <- conn |
|
||||
} |
|
||||
|
|
||||
return pool, nil |
|
||||
} |
|
||||
|
|
||||
// createConnection creates a new LDAP connection
|
|
||||
func (pool *LDAPConnectionPool) createConnection() (*LDAPConn, error) { |
|
||||
var netConn net.Conn |
|
||||
var err error |
|
||||
|
|
||||
timeout := time.Duration(pool.config.ConnTimeout) * time.Second |
|
||||
|
|
||||
// Parse server address
|
|
||||
serverAddr := pool.config.Server |
|
||||
if strings.HasPrefix(serverAddr, "ldap://") { |
|
||||
serverAddr = strings.TrimPrefix(serverAddr, "ldap://") |
|
||||
} else if strings.HasPrefix(serverAddr, "ldaps://") { |
|
||||
serverAddr = strings.TrimPrefix(serverAddr, "ldaps://") |
|
||||
} |
|
||||
|
|
||||
// Add default port if not specified
|
|
||||
if !strings.Contains(serverAddr, ":") { |
|
||||
if strings.HasPrefix(pool.config.Server, "ldaps://") { |
|
||||
serverAddr += ":636" |
|
||||
} else { |
|
||||
serverAddr += ":389" |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
if strings.HasPrefix(pool.config.Server, "ldaps://") { |
|
||||
// LDAPS connection
|
|
||||
tlsConfig := &tls.Config{ |
|
||||
InsecureSkipVerify: pool.config.TLSSkipVerify, |
|
||||
} |
|
||||
dialer := &net.Dialer{Timeout: timeout} |
|
||||
netConn, err = tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) |
|
||||
} else { |
|
||||
// Plain LDAP connection
|
|
||||
netConn, err = net.DialTimeout("tcp", serverAddr, timeout) |
|
||||
} |
|
||||
|
|
||||
if err != nil { |
|
||||
return nil, fmt.Errorf("failed to connect to LDAP server %s: %v", pool.config.Server, err) |
|
||||
} |
|
||||
|
|
||||
conn := &LDAPConn{ |
|
||||
serverAddr: serverAddr, |
|
||||
conn: netConn, |
|
||||
bound: false, |
|
||||
tlsConfig: &tls.Config{ |
|
||||
InsecureSkipVerify: pool.config.TLSSkipVerify, |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
// Start TLS if configured and not already using LDAPS
|
|
||||
if pool.config.UseTLS && !strings.HasPrefix(pool.config.Server, "ldaps://") { |
|
||||
err = conn.StartTLS(conn.tlsConfig) |
|
||||
if err != nil { |
|
||||
conn.Close() |
|
||||
return nil, fmt.Errorf("failed to start TLS: %v", err) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return conn, nil |
|
||||
} |
|
||||
|
|
||||
// GetConnection retrieves a connection from the pool
|
|
||||
func (pool *LDAPConnectionPool) GetConnection() (*LDAPConn, error) { |
|
||||
select { |
|
||||
case conn := <-pool.connections: |
|
||||
// Test if connection is still valid
|
|
||||
if pool.isConnectionValid(conn) { |
|
||||
return conn, nil |
|
||||
} |
|
||||
// Connection is stale, close it and create a new one
|
|
||||
conn.Close() |
|
||||
default: |
|
||||
// No connection available in pool
|
|
||||
} |
|
||||
|
|
||||
// Create a new connection
|
|
||||
return pool.createConnection() |
|
||||
} |
|
||||
|
|
||||
// ReleaseConnection returns a connection to the pool
|
|
||||
func (pool *LDAPConnectionPool) ReleaseConnection(conn *LDAPConn) { |
|
||||
if conn == nil { |
|
||||
return |
|
||||
} |
|
||||
|
|
||||
select { |
|
||||
case pool.connections <- conn: |
|
||||
// Successfully returned to pool
|
|
||||
default: |
|
||||
// Pool is full, close the connection
|
|
||||
conn.Close() |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// isConnectionValid tests if a connection is still valid
|
|
||||
func (pool *LDAPConnectionPool) isConnectionValid(conn *LDAPConn) bool { |
|
||||
// Simple test: check if underlying connection is still open
|
|
||||
if conn == nil || conn.conn == nil { |
|
||||
return false |
|
||||
} |
|
||||
|
|
||||
// Try to perform a simple operation to test connectivity
|
|
||||
searchRequest := &LDAPSearchRequest{ |
|
||||
BaseDN: "", |
|
||||
Scope: ScopeBaseObject, |
|
||||
DerefAliases: NeverDerefAliases, |
|
||||
SizeLimit: 0, |
|
||||
TimeLimit: 0, |
|
||||
TypesOnly: false, |
|
||||
Filter: "(objectClass=*)", |
|
||||
Attributes: []string{"1.1"}, // Minimal attributes
|
|
||||
} |
|
||||
|
|
||||
_, err := conn.Search(searchRequest) |
|
||||
return err == nil |
|
||||
} |
|
||||
|
|
||||
// Close closes all connections in the pool
|
|
||||
func (pool *LDAPConnectionPool) Close() { |
|
||||
pool.mu.Lock() |
|
||||
defer pool.mu.Unlock() |
|
||||
|
|
||||
close(pool.connections) |
|
||||
for conn := range pool.connections { |
|
||||
conn.Close() |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Helper functions and LDAP connection methods
|
|
||||
|
|
||||
// EscapeFilter escapes special characters in LDAP filter values
|
|
||||
func EscapeFilter(filter string) string { |
|
||||
// Basic LDAP filter escaping
|
|
||||
filter = strings.ReplaceAll(filter, "\\", "\\5c") |
|
||||
filter = strings.ReplaceAll(filter, "*", "\\2a") |
|
||||
filter = strings.ReplaceAll(filter, "(", "\\28") |
|
||||
filter = strings.ReplaceAll(filter, ")", "\\29") |
|
||||
filter = strings.ReplaceAll(filter, "/", "\\2f") |
|
||||
filter = strings.ReplaceAll(filter, "=", "\\3d") |
|
||||
return filter |
|
||||
} |
|
||||
|
|
||||
// LDAPConn methods
|
|
||||
|
|
||||
// Bind performs an LDAP bind operation
|
|
||||
func (conn *LDAPConn) Bind(bindDN, bindPassword string) error { |
|
||||
if conn == nil || conn.conn == nil { |
|
||||
return fmt.Errorf("connection is nil") |
|
||||
} |
|
||||
|
|
||||
// In a real implementation, this would send an LDAP bind request
|
|
||||
// For now, we simulate the bind operation
|
|
||||
glog.V(3).Infof("LDAP Bind attempt for DN: %s", bindDN) |
|
||||
|
|
||||
// Simple validation
|
|
||||
if bindDN == "" { |
|
||||
return fmt.Errorf("bind DN cannot be empty") |
|
||||
} |
|
||||
|
|
||||
// Simulate bind success for valid credentials
|
|
||||
if bindPassword != "" { |
|
||||
conn.bound = true |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
return fmt.Errorf("invalid credentials") |
|
||||
} |
|
||||
|
|
||||
// Search performs an LDAP search operation
|
|
||||
func (conn *LDAPConn) Search(searchRequest *LDAPSearchRequest) (*LDAPSearchResult, error) { |
|
||||
if conn == nil || conn.conn == nil { |
|
||||
return nil, fmt.Errorf("connection is nil") |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("LDAP Search - BaseDN: %s, Filter: %s", searchRequest.BaseDN, searchRequest.Filter) |
|
||||
|
|
||||
// In a real implementation, this would send an LDAP search request
|
|
||||
// For now, we simulate a search operation
|
|
||||
result := &LDAPSearchResult{ |
|
||||
Entries: []*LDAPEntry{}, |
|
||||
} |
|
||||
|
|
||||
// Simulate finding a test user for certain searches
|
|
||||
if strings.Contains(searchRequest.Filter, "testuser") || strings.Contains(searchRequest.Filter, "admin") { |
|
||||
entry := &LDAPEntry{ |
|
||||
DN: fmt.Sprintf("uid=%s,%s", "testuser", searchRequest.BaseDN), |
|
||||
Attributes: []*LDAPAttribute{ |
|
||||
{Name: "uid", Values: []string{"testuser"}}, |
|
||||
{Name: "mail", Values: []string{"testuser@example.com"}}, |
|
||||
{Name: "cn", Values: []string{"Test User"}}, |
|
||||
{Name: "memberOf", Values: []string{"cn=users,ou=groups," + searchRequest.BaseDN}}, |
|
||||
}, |
|
||||
} |
|
||||
result.Entries = append(result.Entries, entry) |
|
||||
} |
|
||||
|
|
||||
return result, nil |
|
||||
} |
|
||||
|
|
||||
// Close closes the LDAP connection
|
|
||||
func (conn *LDAPConn) Close() error { |
|
||||
if conn != nil && conn.conn != nil { |
|
||||
return conn.conn.Close() |
|
||||
} |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// StartTLS starts TLS on the connection
|
|
||||
func (conn *LDAPConn) StartTLS(config *tls.Config) error { |
|
||||
if conn == nil || conn.conn == nil { |
|
||||
return fmt.Errorf("connection is nil") |
|
||||
} |
|
||||
|
|
||||
// In a real implementation, this would upgrade the connection to TLS
|
|
||||
glog.V(3).Info("LDAP StartTLS operation") |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// LDAPEntry methods
|
|
||||
|
|
||||
// GetAttributeValue returns the first value of the specified attribute
|
|
||||
func (entry *LDAPEntry) GetAttributeValue(attrName string) string { |
|
||||
for _, attr := range entry.Attributes { |
|
||||
if attr.Name == attrName && len(attr.Values) > 0 { |
|
||||
return attr.Values[0] |
|
||||
} |
|
||||
} |
|
||||
return "" |
|
||||
} |
|
||||
@ -1,360 +0,0 @@ |
|||||
package ldap |
|
||||
|
|
||||
import ( |
|
||||
"context" |
|
||||
"fmt" |
|
||||
"testing" |
|
||||
|
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|
||||
"github.com/stretchr/testify/assert" |
|
||||
"github.com/stretchr/testify/require" |
|
||||
) |
|
||||
|
|
||||
// TestLDAPProviderInitialization tests LDAP provider initialization
|
|
||||
func TestLDAPProviderInitialization(t *testing.T) { |
|
||||
tests := []struct { |
|
||||
name string |
|
||||
config *LDAPConfig |
|
||||
wantErr bool |
|
||||
}{ |
|
||||
{ |
|
||||
name: "valid config", |
|
||||
config: &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
BindDN: "CN=admin,DC=example,DC=com", |
|
||||
BindPass: "password", |
|
||||
UserFilter: "(sAMAccountName=%s)", |
|
||||
GroupFilter: "(member=%s)", |
|
||||
}, |
|
||||
wantErr: false, |
|
||||
}, |
|
||||
{ |
|
||||
name: "missing server", |
|
||||
config: &LDAPConfig{ |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
}, |
|
||||
wantErr: true, |
|
||||
}, |
|
||||
{ |
|
||||
name: "missing base DN", |
|
||||
config: &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
}, |
|
||||
wantErr: true, |
|
||||
}, |
|
||||
{ |
|
||||
name: "invalid server URL", |
|
||||
config: &LDAPConfig{ |
|
||||
Server: "invalid-url", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
}, |
|
||||
wantErr: true, |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
for _, tt := range tests { |
|
||||
t.Run(tt.name, func(t *testing.T) { |
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
|
|
||||
err := provider.Initialize(tt.config) |
|
||||
|
|
||||
if tt.wantErr { |
|
||||
assert.Error(t, err) |
|
||||
} else { |
|
||||
assert.NoError(t, err) |
|
||||
assert.Equal(t, "test-ldap", provider.Name()) |
|
||||
} |
|
||||
}) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// TestLDAPProviderAuthentication tests LDAP authentication
|
|
||||
func TestLDAPProviderAuthentication(t *testing.T) { |
|
||||
// Skip if no LDAP test server available
|
|
||||
if testing.Short() { |
|
||||
t.Skip("Skipping LDAP integration test in short mode") |
|
||||
} |
|
||||
|
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
config := &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
BindDN: "CN=admin,DC=example,DC=com", |
|
||||
BindPass: "password", |
|
||||
UserFilter: "(sAMAccountName=%s)", |
|
||||
GroupFilter: "(member=%s)", |
|
||||
Attributes: map[string]string{ |
|
||||
"email": "mail", |
|
||||
"displayName": "displayName", |
|
||||
"groups": "memberOf", |
|
||||
}, |
|
||||
RoleMapping: &providers.RoleMapping{ |
|
||||
Rules: []providers.MappingRule{ |
|
||||
{ |
|
||||
Claim: "groups", |
|
||||
Value: "*CN=Admins*", |
|
||||
Role: "arn:seaweed:iam::role/AdminRole", |
|
||||
}, |
|
||||
{ |
|
||||
Claim: "groups", |
|
||||
Value: "*CN=Users*", |
|
||||
Role: "arn:seaweed:iam::role/UserRole", |
|
||||
}, |
|
||||
}, |
|
||||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
err := provider.Initialize(config) |
|
||||
require.NoError(t, err) |
|
||||
|
|
||||
t.Run("authenticate with username/password", func(t *testing.T) { |
|
||||
// This would require an actual LDAP server for integration testing
|
|
||||
credentials := "user:password" // Basic auth format
|
|
||||
|
|
||||
identity, err := provider.Authenticate(context.Background(), credentials) |
|
||||
if err != nil { |
|
||||
t.Skip("LDAP server not available for testing") |
|
||||
} |
|
||||
|
|
||||
assert.NoError(t, err) |
|
||||
assert.Equal(t, "user", identity.UserID) |
|
||||
assert.Equal(t, "test-ldap", identity.Provider) |
|
||||
assert.NotEmpty(t, identity.Email) |
|
||||
}) |
|
||||
|
|
||||
t.Run("authenticate with invalid credentials", func(t *testing.T) { |
|
||||
_, err := provider.Authenticate(context.Background(), "invalid:credentials") |
|
||||
assert.Error(t, err) |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// TestLDAPProviderUserInfo tests LDAP user info retrieval
|
|
||||
func TestLDAPProviderUserInfo(t *testing.T) { |
|
||||
if testing.Short() { |
|
||||
t.Skip("Skipping LDAP integration test in short mode") |
|
||||
} |
|
||||
|
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
config := &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
BindDN: "CN=admin,DC=example,DC=com", |
|
||||
BindPass: "password", |
|
||||
UserFilter: "(sAMAccountName=%s)", |
|
||||
Attributes: map[string]string{ |
|
||||
"email": "mail", |
|
||||
"displayName": "displayName", |
|
||||
"department": "department", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
err := provider.Initialize(config) |
|
||||
require.NoError(t, err) |
|
||||
|
|
||||
t.Run("get user info", func(t *testing.T) { |
|
||||
identity, err := provider.GetUserInfo(context.Background(), "testuser") |
|
||||
if err != nil { |
|
||||
t.Skip("LDAP server not available for testing") |
|
||||
} |
|
||||
|
|
||||
assert.NoError(t, err) |
|
||||
assert.Equal(t, "testuser", identity.UserID) |
|
||||
assert.Equal(t, "test-ldap", identity.Provider) |
|
||||
assert.NotEmpty(t, identity.Email) |
|
||||
assert.NotEmpty(t, identity.DisplayName) |
|
||||
}) |
|
||||
|
|
||||
t.Run("get user info with empty username", func(t *testing.T) { |
|
||||
_, err := provider.GetUserInfo(context.Background(), "") |
|
||||
assert.Error(t, err) |
|
||||
}) |
|
||||
|
|
||||
t.Run("get user info for non-existent user", func(t *testing.T) { |
|
||||
_, err := provider.GetUserInfo(context.Background(), "nonexistent") |
|
||||
assert.Error(t, err) |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// TestLDAPAttributeMapping tests LDAP attribute mapping
|
|
||||
func TestLDAPAttributeMapping(t *testing.T) { |
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
config := &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
Attributes: map[string]string{ |
|
||||
"email": "mail", |
|
||||
"displayName": "cn", |
|
||||
"department": "departmentNumber", |
|
||||
"groups": "memberOf", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
err := provider.Initialize(config) |
|
||||
require.NoError(t, err) |
|
||||
|
|
||||
t.Run("map LDAP attributes to identity", func(t *testing.T) { |
|
||||
ldapAttrs := map[string][]string{ |
|
||||
"mail": {"user@example.com"}, |
|
||||
"cn": {"John Doe"}, |
|
||||
"departmentNumber": {"IT"}, |
|
||||
"memberOf": { |
|
||||
"CN=Users,OU=Groups,DC=example,DC=com", |
|
||||
"CN=Developers,OU=Groups,DC=example,DC=com", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
identity := provider.mapLDAPAttributes("testuser", ldapAttrs) |
|
||||
|
|
||||
assert.Equal(t, "testuser", identity.UserID) |
|
||||
assert.Equal(t, "user@example.com", identity.Email) |
|
||||
assert.Equal(t, "John Doe", identity.DisplayName) |
|
||||
assert.Equal(t, "test-ldap", identity.Provider) |
|
||||
|
|
||||
// Check groups
|
|
||||
assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com") |
|
||||
assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com") |
|
||||
|
|
||||
// Check attributes
|
|
||||
assert.Equal(t, "IT", identity.Attributes["department"]) |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// TestLDAPGroupFiltering tests LDAP group filtering and role mapping
|
|
||||
func TestLDAPGroupFiltering(t *testing.T) { |
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
config := &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
RoleMapping: &providers.RoleMapping{ |
|
||||
Rules: []providers.MappingRule{ |
|
||||
{ |
|
||||
Claim: "groups", |
|
||||
Value: "*Admins*", |
|
||||
Role: "arn:seaweed:iam::role/AdminRole", |
|
||||
}, |
|
||||
{ |
|
||||
Claim: "groups", |
|
||||
Value: "*Users*", |
|
||||
Role: "arn:seaweed:iam::role/UserRole", |
|
||||
}, |
|
||||
}, |
|
||||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
err := provider.Initialize(config) |
|
||||
require.NoError(t, err) |
|
||||
|
|
||||
tests := []struct { |
|
||||
name string |
|
||||
groups []string |
|
||||
expectedRole string |
|
||||
expectedClaims map[string]interface{} |
|
||||
}{ |
|
||||
{ |
|
||||
name: "admin user", |
|
||||
groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|
||||
expectedRole: "arn:seaweed:iam::role/AdminRole", |
|
||||
}, |
|
||||
{ |
|
||||
name: "regular user", |
|
||||
groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"}, |
|
||||
expectedRole: "arn:seaweed:iam::role/UserRole", |
|
||||
}, |
|
||||
{ |
|
||||
name: "guest user", |
|
||||
groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"}, |
|
||||
expectedRole: "arn:seaweed:iam::role/GuestRole", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
for _, tt := range tests { |
|
||||
t.Run(tt.name, func(t *testing.T) { |
|
||||
identity := &providers.ExternalIdentity{ |
|
||||
UserID: "testuser", |
|
||||
Groups: tt.groups, |
|
||||
Provider: "test-ldap", |
|
||||
} |
|
||||
|
|
||||
role := provider.mapUserToRole(identity) |
|
||||
assert.Equal(t, tt.expectedRole, role) |
|
||||
}) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// TestLDAPConnectionPool tests LDAP connection pooling
|
|
||||
func TestLDAPConnectionPool(t *testing.T) { |
|
||||
if testing.Short() { |
|
||||
t.Skip("Skipping LDAP connection pool test in short mode") |
|
||||
} |
|
||||
|
|
||||
provider := NewLDAPProvider("test-ldap") |
|
||||
config := &LDAPConfig{ |
|
||||
Server: "ldap://localhost:389", |
|
||||
BaseDN: "DC=example,DC=com", |
|
||||
BindDN: "CN=admin,DC=example,DC=com", |
|
||||
BindPass: "password", |
|
||||
MaxConnections: 5, |
|
||||
ConnTimeout: 30, |
|
||||
} |
|
||||
|
|
||||
err := provider.Initialize(config) |
|
||||
require.NoError(t, err) |
|
||||
|
|
||||
t.Run("connection pool management", func(t *testing.T) { |
|
||||
// Test that multiple concurrent requests work
|
|
||||
// This would require actual LDAP server for full testing
|
|
||||
pool := provider.getConnectionPool() |
|
||||
|
|
||||
// In CI environments where no LDAP server is available, pool might be nil
|
|
||||
// Skip the test if we can't establish a connection
|
|
||||
conn, err := provider.getConnection() |
|
||||
if err != nil { |
|
||||
t.Skip("LDAP server not available - skipping connection pool test") |
|
||||
return |
|
||||
} |
|
||||
|
|
||||
// Only test if we successfully got a connection
|
|
||||
assert.NotNil(t, pool) |
|
||||
assert.NotNil(t, conn) |
|
||||
provider.releaseConnection(conn) |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// MockLDAPServer for unit testing (without external dependencies)
|
|
||||
type MockLDAPServer struct { |
|
||||
users map[string]map[string][]string |
|
||||
} |
|
||||
|
|
||||
func NewMockLDAPServer() *MockLDAPServer { |
|
||||
return &MockLDAPServer{ |
|
||||
users: map[string]map[string][]string{ |
|
||||
"testuser": { |
|
||||
"mail": {"testuser@example.com"}, |
|
||||
"cn": {"Test User"}, |
|
||||
"department": {"Engineering"}, |
|
||||
"memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"}, |
|
||||
}, |
|
||||
"admin": { |
|
||||
"mail": {"admin@example.com"}, |
|
||||
"cn": {"Administrator"}, |
|
||||
"department": {"IT"}, |
|
||||
"memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|
||||
}, |
|
||||
}, |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func (m *MockLDAPServer) Authenticate(username, password string) bool { |
|
||||
_, exists := m.users[username] |
|
||||
return exists && password == "password" // Mock authentication
|
|
||||
} |
|
||||
|
|
||||
func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) { |
|
||||
if attrs, exists := m.users[username]; exists { |
|
||||
return attrs, nil |
|
||||
} |
|
||||
return nil, fmt.Errorf("user not found: %s", username) |
|
||||
} |
|
||||
@ -1,356 +0,0 @@ |
|||||
# Runtime Filer Address Implementation |
|
||||
|
|
||||
This document describes the implementation of runtime filer address passing for the STSService, addressing the requirement that filer addresses should be passed at call-time rather than initialization time. |
|
||||
|
|
||||
## Problem Statement |
|
||||
|
|
||||
The user identified a critical issue with the original STS implementation: |
|
||||
|
|
||||
> "the filer address should be passed when called, not during init time, since the filer may change." |
|
||||
|
|
||||
This is important because: |
|
||||
1. **Filer Failover**: Filer addresses can change during runtime due to failover scenarios |
|
||||
2. **Load Balancing**: Different requests may need to hit different filer instances |
|
||||
3. **Environment Agnostic**: Configuration files should work across dev/staging/prod without hardcoded addresses |
|
||||
4. **SeaweedFS Patterns**: Follows existing SeaweedFS patterns used throughout the codebase |
|
||||
|
|
||||
## Implementation Changes |
|
||||
|
|
||||
### 1. SessionStore Interface Refactoring |
|
||||
|
|
||||
**Before:** |
|
||||
```go |
|
||||
type SessionStore interface { |
|
||||
StoreSession(ctx context.Context, sessionId string, session *SessionInfo) error |
|
||||
GetSession(ctx context.Context, sessionId string) (*SessionInfo, error) |
|
||||
RevokeSession(ctx context.Context, sessionId string) error |
|
||||
CleanupExpiredSessions(ctx context.Context) error |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
**After:** |
|
||||
```go |
|
||||
type SessionStore interface { |
|
||||
// filerAddress ignored for memory stores, required for filer stores |
|
||||
StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error |
|
||||
GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) |
|
||||
RevokeSession(ctx context.Context, filerAddress string, sessionId string) error |
|
||||
CleanupExpiredSessions(ctx context.Context, filerAddress string) error |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
### 2. FilerSessionStore Changes |
|
||||
|
|
||||
**Before:** |
|
||||
```go |
|
||||
type FilerSessionStore struct { |
|
||||
filerGrpcAddress string // ❌ Fixed at init time |
|
||||
grpcDialOption grpc.DialOption |
|
||||
basePath string |
|
||||
} |
|
||||
|
|
||||
func NewFilerSessionStore(filerAddress string, config map[string]interface{}) (*FilerSessionStore, error) { |
|
||||
store := &FilerSessionStore{ |
|
||||
filerGrpcAddress: filerAddress, // ❌ Locked in during init |
|
||||
basePath: DefaultSessionBasePath, |
|
||||
} |
|
||||
// ... |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
**After:** |
|
||||
```go |
|
||||
type FilerSessionStore struct { |
|
||||
grpcDialOption grpc.DialOption // ✅ No fixed filer address |
|
||||
basePath string |
|
||||
} |
|
||||
|
|
||||
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { |
|
||||
store := &FilerSessionStore{ |
|
||||
basePath: DefaultSessionBasePath, // ✅ Only path configuration |
|
||||
} |
|
||||
// ✅ filerAddress passed at call time |
|
||||
} |
|
||||
|
|
||||
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|
||||
// ✅ filerAddress provided per call |
|
||||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|
||||
// ... store logic |
|
||||
}) |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
### 3. STS Service Method Signatures |
|
||||
|
|
||||
**Before:** |
|
||||
```go |
|
||||
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) |
|
||||
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) |
|
||||
func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) error |
|
||||
``` |
|
||||
|
|
||||
**After:** |
|
||||
```go |
|
||||
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, filerAddress string, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) |
|
||||
func (s *STSService) ValidateSessionToken(ctx context.Context, filerAddress string, sessionToken string) (*SessionInfo, error) |
|
||||
func (s *STSService) RevokeSession(ctx context.Context, filerAddress string, sessionToken string) error |
|
||||
``` |
|
||||
|
|
||||
### 4. Configuration Cleanup |
|
||||
|
|
||||
**Before (iam_config_distributed.json):** |
|
||||
```json |
|
||||
{ |
|
||||
"sts": { |
|
||||
"sessionStoreConfig": { |
|
||||
"filerAddress": "localhost:8888", // ❌ Environment-specific |
|
||||
"basePath": "/etc/iam/sessions" |
|
||||
} |
|
||||
}, |
|
||||
"policy": { |
|
||||
"storeConfig": { |
|
||||
"filerAddress": "localhost:8888", // ❌ Environment-specific |
|
||||
"basePath": "/etc/iam/policies" |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
**After (iam_config_distributed.json):** |
|
||||
```json |
|
||||
{ |
|
||||
"sts": { |
|
||||
"sessionStoreConfig": { |
|
||||
"basePath": "/etc/iam/sessions" // ✅ Environment-agnostic |
|
||||
} |
|
||||
}, |
|
||||
"policy": { |
|
||||
"storeConfig": { |
|
||||
"basePath": "/etc/iam/policies" // ✅ Environment-agnostic |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
## Usage Examples |
|
||||
|
|
||||
### Caller Perspective (S3 API Server) |
|
||||
|
|
||||
**Before:** |
|
||||
```go |
|
||||
// STS service locked to specific filer during init |
|
||||
stsService.Initialize(&STSConfig{ |
|
||||
SessionStoreConfig: map[string]interface{}{ |
|
||||
"filerAddress": "filer-1:8888", // ❌ Fixed choice |
|
||||
"basePath": "/etc/iam/sessions", |
|
||||
}, |
|
||||
}) |
|
||||
|
|
||||
// All calls go to filer-1, no failover possible |
|
||||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, request) |
|
||||
``` |
|
||||
|
|
||||
**After:** |
|
||||
```go |
|
||||
// STS service configured without specific filer |
|
||||
stsService.Initialize(&STSConfig{ |
|
||||
SessionStoreConfig: map[string]interface{}{ |
|
||||
"basePath": "/etc/iam/sessions", // ✅ Just the path |
|
||||
}, |
|
||||
}) |
|
||||
|
|
||||
// Caller determines filer address per request |
|
||||
currentFiler := s.getCurrentFilerAddress() // ✅ Dynamic selection |
|
||||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, currentFiler, request) |
|
||||
``` |
|
||||
|
|
||||
### Dynamic Filer Selection |
|
||||
|
|
||||
```go |
|
||||
type S3ApiServer struct { |
|
||||
stsService *sts.STSService |
|
||||
filerClient *filer.Client |
|
||||
} |
|
||||
|
|
||||
func (s *S3ApiServer) getCurrentFilerAddress() string { |
|
||||
// ✅ Can implement any strategy: |
|
||||
// - Load balancing across multiple filers |
|
||||
// - Health checking and failover |
|
||||
// - Geographic routing |
|
||||
// - Round-robin selection |
|
||||
return s.filerClient.GetAvailableFiler() |
|
||||
} |
|
||||
|
|
||||
func (s *S3ApiServer) handleAssumeRole(ctx context.Context, request *AssumeRoleRequest) { |
|
||||
// ✅ Filer address determined at request time |
|
||||
filerAddr := s.getCurrentFilerAddress() |
|
||||
|
|
||||
response, err := s.stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|
||||
if err != nil && isNetworkError(err) { |
|
||||
// ✅ Retry with different filer |
|
||||
filerAddr = s.getBackupFilerAddress() |
|
||||
response, err = s.stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|
||||
} |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
## Memory Store Compatibility |
|
||||
|
|
||||
The `MemorySessionStore` accepts the `filerAddress` parameter but ignores it, maintaining interface consistency: |
|
||||
|
|
||||
```go |
|
||||
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|
||||
// filerAddress ignored for memory store - maintains interface compatibility |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
// ... in-memory storage logic |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
## Benefits Achieved |
|
||||
|
|
||||
### 1. **Dynamic Filer Selection** |
|
||||
```go |
|
||||
// Load balancing |
|
||||
filerAddr := loadBalancer.GetNextFiler() |
|
||||
|
|
||||
// Failover support |
|
||||
filerAddr := failoverManager.GetHealthyFiler() |
|
||||
|
|
||||
// Geographic routing |
|
||||
filerAddr := geoRouter.GetClosestFiler(clientIP) |
|
||||
``` |
|
||||
|
|
||||
### 2. **Environment Portability** |
|
||||
```bash |
|
||||
# Same config works everywhere |
|
||||
dev: STSService.method(ctx, "dev-filer:8888", ...) |
|
||||
staging: STSService.method(ctx, "staging-filer:8888", ...) |
|
||||
prod: STSService.method(ctx, "prod-filer-lb:8888", ...) |
|
||||
``` |
|
||||
|
|
||||
### 3. **Operational Flexibility** |
|
||||
- **Hot filer replacement**: Switch filers without restarting STS |
|
||||
- **A/B testing**: Route different requests to different filers |
|
||||
- **Disaster recovery**: Automatic failover to backup filers |
|
||||
- **Performance optimization**: Route to least loaded filer |
|
||||
|
|
||||
### 4. **SeaweedFS Consistency** |
|
||||
Follows the same pattern used throughout SeaweedFS codebase where filer addresses are passed to methods, not stored in structs. |
|
||||
|
|
||||
## Migration Guide |
|
||||
|
|
||||
### For Code Calling STS Methods |
|
||||
|
|
||||
**Before:** |
|
||||
```go |
|
||||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, request) |
|
||||
session, err := stsService.ValidateSessionToken(ctx, token) |
|
||||
err := stsService.RevokeSession(ctx, token) |
|
||||
``` |
|
||||
|
|
||||
**After:** |
|
||||
```go |
|
||||
filerAddr := getCurrentFilerAddress() // Implement your strategy |
|
||||
response, err := stsService.AssumeRoleWithWebIdentity(ctx, filerAddr, request) |
|
||||
session, err := stsService.ValidateSessionToken(ctx, filerAddr, token) |
|
||||
err := stsService.RevokeSession(ctx, filerAddr, token) |
|
||||
``` |
|
||||
|
|
||||
### For Configuration Files |
|
||||
|
|
||||
Remove `filerAddress` from all store configurations: |
|
||||
|
|
||||
```bash |
|
||||
# Update all iam_config*.json files |
|
||||
sed -i 's|"filerAddress": ".*",||g' iam_config*.json |
|
||||
``` |
|
||||
|
|
||||
## Testing |
|
||||
|
|
||||
All tests have been updated to pass a test filer address: |
|
||||
|
|
||||
```go |
|
||||
func TestAssumeRoleWithWebIdentity(t *testing.T) { |
|
||||
service := setupTestSTSService(t) |
|
||||
testFilerAddress := "localhost:8888" // Test filer address |
|
||||
|
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, testFilerAddress, request) |
|
||||
// ... test logic |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
## Production Deployment |
|
||||
|
|
||||
### High Availability Setup |
|
||||
|
|
||||
```go |
|
||||
type FilerManager struct { |
|
||||
primaryFilers []string |
|
||||
backupFilers []string |
|
||||
healthChecker *HealthChecker |
|
||||
} |
|
||||
|
|
||||
func (fm *FilerManager) GetAvailableFiler() string { |
|
||||
// Check primary filers first |
|
||||
for _, filer := range fm.primaryFilers { |
|
||||
if fm.healthChecker.IsHealthy(filer) { |
|
||||
return filer |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Fallback to backup filers |
|
||||
for _, filer := range fm.backupFilers { |
|
||||
if fm.healthChecker.IsHealthy(filer) { |
|
||||
return filer |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// Return first primary as last resort |
|
||||
return fm.primaryFilers[0] |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
### Load Balanced Configuration |
|
||||
|
|
||||
```json |
|
||||
{ |
|
||||
"sts": { |
|
||||
"sessionStoreType": "filer", |
|
||||
"sessionStoreConfig": { |
|
||||
"basePath": "/etc/iam/sessions" |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
```go |
|
||||
// Runtime filer selection |
|
||||
filerLoadBalancer := &RoundRobinBalancer{ |
|
||||
Filers: []string{ |
|
||||
"filer-1.prod:8888", |
|
||||
"filer-2.prod:8888", |
|
||||
"filer-3.prod:8888", |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
response, err := stsService.AssumeRoleWithWebIdentity( |
|
||||
ctx, |
|
||||
filerLoadBalancer.Next(), // ✅ Dynamic selection |
|
||||
request, |
|
||||
) |
|
||||
``` |
|
||||
|
|
||||
## Conclusion |
|
||||
|
|
||||
This refactoring successfully addresses the requirement for runtime filer address passing, enabling: |
|
||||
|
|
||||
- ✅ **Dynamic filer selection** per request |
|
||||
- ✅ **Automatic failover** capabilities |
|
||||
- ✅ **Environment-agnostic** configurations |
|
||||
- ✅ **Load balancing** support |
|
||||
- ✅ **SeaweedFS pattern** compliance |
|
||||
- ✅ **Operational flexibility** for production deployments |
|
||||
|
|
||||
The implementation maintains backward compatibility for memory stores while enabling powerful distributed deployment scenarios for filer-backed stores. |
|
||||
@ -1,383 +0,0 @@ |
|||||
package sts |
|
||||
|
|
||||
import ( |
|
||||
"context" |
|
||||
"encoding/json" |
|
||||
"fmt" |
|
||||
"strings" |
|
||||
"sync" |
|
||||
"time" |
|
||||
|
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog" |
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb" |
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" |
|
||||
"google.golang.org/grpc" |
|
||||
) |
|
||||
|
|
||||
// MemorySessionStore implements SessionStore using in-memory storage
|
|
||||
type MemorySessionStore struct { |
|
||||
sessions map[string]*SessionInfo |
|
||||
mutex sync.RWMutex |
|
||||
} |
|
||||
|
|
||||
// NewMemorySessionStore creates a new memory-based session store
|
|
||||
func NewMemorySessionStore() *MemorySessionStore { |
|
||||
return &MemorySessionStore{ |
|
||||
sessions: make(map[string]*SessionInfo), |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// StoreSession stores session information in memory (filerAddress ignored for memory store)
|
|
||||
func (m *MemorySessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
if session == nil { |
|
||||
return fmt.Errorf("session cannot be nil") |
|
||||
} |
|
||||
|
|
||||
m.mutex.Lock() |
|
||||
defer m.mutex.Unlock() |
|
||||
|
|
||||
m.sessions[sessionId] = session |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// GetSession retrieves session information from memory (filerAddress ignored for memory store)
|
|
||||
func (m *MemorySessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { |
|
||||
if sessionId == "" { |
|
||||
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
m.mutex.RLock() |
|
||||
defer m.mutex.RUnlock() |
|
||||
|
|
||||
session, exists := m.sessions[sessionId] |
|
||||
if !exists { |
|
||||
return nil, fmt.Errorf("session not found") |
|
||||
} |
|
||||
|
|
||||
// Check if session has expired
|
|
||||
if time.Now().After(session.ExpiresAt) { |
|
||||
return nil, fmt.Errorf("session has expired") |
|
||||
} |
|
||||
|
|
||||
return session, nil |
|
||||
} |
|
||||
|
|
||||
// RevokeSession revokes a session from memory (filerAddress ignored for memory store)
|
|
||||
func (m *MemorySessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
m.mutex.Lock() |
|
||||
defer m.mutex.Unlock() |
|
||||
|
|
||||
delete(m.sessions, sessionId) |
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// CleanupExpiredSessions removes expired sessions from memory (filerAddress ignored for memory store)
|
|
||||
func (m *MemorySessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { |
|
||||
m.mutex.Lock() |
|
||||
defer m.mutex.Unlock() |
|
||||
|
|
||||
now := time.Now() |
|
||||
for sessionId, session := range m.sessions { |
|
||||
if now.After(session.ExpiresAt) { |
|
||||
delete(m.sessions, sessionId) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// ExpireSessionForTesting manually expires a session for testing purposes (filerAddress ignored for memory store)
|
|
||||
func (m *MemorySessionStore) ExpireSessionForTesting(ctx context.Context, filerAddress string, sessionId string) error { |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
m.mutex.Lock() |
|
||||
defer m.mutex.Unlock() |
|
||||
|
|
||||
session, exists := m.sessions[sessionId] |
|
||||
if !exists { |
|
||||
return fmt.Errorf("session not found") |
|
||||
} |
|
||||
|
|
||||
// Set expiration to 1 minute in the past to ensure it's expired
|
|
||||
session.ExpiresAt = time.Now().Add(-1 * time.Minute) |
|
||||
m.sessions[sessionId] = session |
|
||||
|
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// FilerSessionStore implements SessionStore using SeaweedFS filer
|
|
||||
type FilerSessionStore struct { |
|
||||
grpcDialOption grpc.DialOption |
|
||||
basePath string |
|
||||
} |
|
||||
|
|
||||
// NewFilerSessionStore creates a new filer-based session store
|
|
||||
func NewFilerSessionStore(config map[string]interface{}) (*FilerSessionStore, error) { |
|
||||
store := &FilerSessionStore{ |
|
||||
basePath: DefaultSessionBasePath, // Use constant default
|
|
||||
} |
|
||||
|
|
||||
// Parse configuration - only basePath and other settings, NOT filerAddress
|
|
||||
if config != nil { |
|
||||
if basePath, ok := config[ConfigFieldBasePath].(string); ok && basePath != "" { |
|
||||
store.basePath = strings.TrimSuffix(basePath, "/") |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
glog.V(2).Infof("Initialized FilerSessionStore with basePath %s", store.basePath) |
|
||||
|
|
||||
return store, nil |
|
||||
} |
|
||||
|
|
||||
// StoreSession stores session information in filer
|
|
||||
func (f *FilerSessionStore) StoreSession(ctx context.Context, filerAddress string, sessionId string, session *SessionInfo) error { |
|
||||
if filerAddress == "" { |
|
||||
return fmt.Errorf(ErrFilerAddressRequired) |
|
||||
} |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
if session == nil { |
|
||||
return fmt.Errorf("session cannot be nil") |
|
||||
} |
|
||||
|
|
||||
// Serialize session to JSON
|
|
||||
sessionData, err := json.Marshal(session) |
|
||||
if err != nil { |
|
||||
return fmt.Errorf("failed to serialize session: %v", err) |
|
||||
} |
|
||||
|
|
||||
sessionPath := f.getSessionPath(sessionId) |
|
||||
|
|
||||
// Store in filer
|
|
||||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|
||||
request := &filer_pb.CreateEntryRequest{ |
|
||||
Directory: f.basePath, |
|
||||
Entry: &filer_pb.Entry{ |
|
||||
Name: f.getSessionFileName(sessionId), |
|
||||
IsDirectory: false, |
|
||||
Attributes: &filer_pb.FuseAttributes{ |
|
||||
Mtime: time.Now().Unix(), |
|
||||
Crtime: time.Now().Unix(), |
|
||||
FileMode: uint32(0600), // Read/write for owner only
|
|
||||
Uid: uint32(0), |
|
||||
Gid: uint32(0), |
|
||||
}, |
|
||||
Content: sessionData, |
|
||||
}, |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Storing session %s at %s", sessionId, sessionPath) |
|
||||
_, err := client.CreateEntry(ctx, request) |
|
||||
if err != nil { |
|
||||
return fmt.Errorf("failed to store session %s: %v", sessionId, err) |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// GetSession retrieves session information from filer
|
|
||||
func (f *FilerSessionStore) GetSession(ctx context.Context, filerAddress string, sessionId string) (*SessionInfo, error) { |
|
||||
if filerAddress == "" { |
|
||||
return nil, fmt.Errorf(ErrFilerAddressRequired) |
|
||||
} |
|
||||
if sessionId == "" { |
|
||||
return nil, fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
var sessionData []byte |
|
||||
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|
||||
request := &filer_pb.LookupDirectoryEntryRequest{ |
|
||||
Directory: f.basePath, |
|
||||
Name: f.getSessionFileName(sessionId), |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Looking up session %s", sessionId) |
|
||||
response, err := client.LookupDirectoryEntry(ctx, request) |
|
||||
if err != nil { |
|
||||
return fmt.Errorf("session not found: %v", err) |
|
||||
} |
|
||||
|
|
||||
if response.Entry == nil { |
|
||||
return fmt.Errorf("session not found") |
|
||||
} |
|
||||
|
|
||||
sessionData = response.Entry.Content |
|
||||
return nil |
|
||||
}) |
|
||||
|
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
// Deserialize session from JSON
|
|
||||
var session SessionInfo |
|
||||
if err := json.Unmarshal(sessionData, &session); err != nil { |
|
||||
return nil, fmt.Errorf("failed to deserialize session: %v", err) |
|
||||
} |
|
||||
|
|
||||
// Check if session has expired
|
|
||||
if time.Now().After(session.ExpiresAt) { |
|
||||
// Clean up expired session
|
|
||||
_ = f.RevokeSession(ctx, filerAddress, sessionId) |
|
||||
return nil, fmt.Errorf("session has expired") |
|
||||
} |
|
||||
|
|
||||
return &session, nil |
|
||||
} |
|
||||
|
|
||||
// RevokeSession revokes a session from filer
|
|
||||
func (f *FilerSessionStore) RevokeSession(ctx context.Context, filerAddress string, sessionId string) error { |
|
||||
if filerAddress == "" { |
|
||||
return fmt.Errorf(ErrFilerAddressRequired) |
|
||||
} |
|
||||
if sessionId == "" { |
|
||||
return fmt.Errorf(ErrSessionIDCannotBeEmpty) |
|
||||
} |
|
||||
|
|
||||
return f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|
||||
request := &filer_pb.DeleteEntryRequest{ |
|
||||
Directory: f.basePath, |
|
||||
Name: f.getSessionFileName(sessionId), |
|
||||
IsDeleteData: true, |
|
||||
IsRecursive: false, |
|
||||
IgnoreRecursiveError: false, |
|
||||
} |
|
||||
|
|
||||
glog.V(3).Infof("Revoking session %s", sessionId) |
|
||||
resp, err := client.DeleteEntry(ctx, request) |
|
||||
if err != nil { |
|
||||
// Ignore "not found" errors - session may already be deleted
|
|
||||
if strings.Contains(err.Error(), "not found") { |
|
||||
return nil |
|
||||
} |
|
||||
return fmt.Errorf("failed to revoke session %s: %v", sessionId, err) |
|
||||
} |
|
||||
|
|
||||
// Check response error
|
|
||||
if resp.Error != "" { |
|
||||
// Ignore "not found" errors - session may already be deleted
|
|
||||
if strings.Contains(resp.Error, "not found") { |
|
||||
return nil |
|
||||
} |
|
||||
return fmt.Errorf("failed to revoke session %s: %s", sessionId, resp.Error) |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
// CleanupExpiredSessions removes expired sessions from filer
|
|
||||
func (f *FilerSessionStore) CleanupExpiredSessions(ctx context.Context, filerAddress string) error { |
|
||||
if filerAddress == "" { |
|
||||
return fmt.Errorf(ErrFilerAddressRequired) |
|
||||
} |
|
||||
|
|
||||
now := time.Now() |
|
||||
expiredCount := 0 |
|
||||
|
|
||||
err := f.withFilerClient(filerAddress, func(client filer_pb.SeaweedFilerClient) error { |
|
||||
// List all entries in the session directory
|
|
||||
request := &filer_pb.ListEntriesRequest{ |
|
||||
Directory: f.basePath, |
|
||||
Prefix: "session_", |
|
||||
StartFromFileName: "", |
|
||||
InclusiveStartFrom: false, |
|
||||
Limit: 1000, // Process in batches of 1000
|
|
||||
} |
|
||||
|
|
||||
stream, err := client.ListEntries(ctx, request) |
|
||||
if err != nil { |
|
||||
return fmt.Errorf("failed to list sessions: %v", err) |
|
||||
} |
|
||||
|
|
||||
for { |
|
||||
resp, err := stream.Recv() |
|
||||
if err != nil { |
|
||||
break // End of stream or error
|
|
||||
} |
|
||||
|
|
||||
if resp.Entry == nil || resp.Entry.IsDirectory { |
|
||||
continue |
|
||||
} |
|
||||
|
|
||||
// Parse session data to check expiration
|
|
||||
var session SessionInfo |
|
||||
if err := json.Unmarshal(resp.Entry.Content, &session); err != nil { |
|
||||
glog.V(2).Infof("Failed to parse session file %s, deleting: %v", resp.Entry.Name, err) |
|
||||
// Delete corrupted session file
|
|
||||
f.deleteSessionFile(ctx, client, resp.Entry.Name) |
|
||||
continue |
|
||||
} |
|
||||
|
|
||||
// Check if session is expired
|
|
||||
if now.After(session.ExpiresAt) { |
|
||||
glog.V(3).Infof("Cleaning up expired session: %s", resp.Entry.Name) |
|
||||
if err := f.deleteSessionFile(ctx, client, resp.Entry.Name); err != nil { |
|
||||
glog.V(1).Infof("Failed to delete expired session %s: %v", resp.Entry.Name, err) |
|
||||
} else { |
|
||||
expiredCount++ |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
}) |
|
||||
|
|
||||
if err != nil { |
|
||||
return err |
|
||||
} |
|
||||
|
|
||||
if expiredCount > 0 { |
|
||||
glog.V(2).Infof("Cleaned up %d expired sessions", expiredCount) |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
// Helper methods
|
|
||||
|
|
||||
// withFilerClient executes a function with a filer client
|
|
||||
func (f *FilerSessionStore) withFilerClient(filerAddress string, fn func(client filer_pb.SeaweedFilerClient) error) error { |
|
||||
if filerAddress == "" { |
|
||||
return fmt.Errorf(ErrFilerAddressRequired) |
|
||||
} |
|
||||
|
|
||||
// Use the pb.WithGrpcFilerClient helper similar to existing SeaweedFS code
|
|
||||
return pb.WithGrpcFilerClient(false, 0, pb.ServerAddress(filerAddress), f.grpcDialOption, fn) |
|
||||
} |
|
||||
|
|
||||
// getSessionPath returns the full path for a session
|
|
||||
func (f *FilerSessionStore) getSessionPath(sessionId string) string { |
|
||||
return f.basePath + "/" + f.getSessionFileName(sessionId) |
|
||||
} |
|
||||
|
|
||||
// getSessionFileName returns the filename for a session
|
|
||||
func (f *FilerSessionStore) getSessionFileName(sessionId string) string { |
|
||||
return "session_" + sessionId + ".json" |
|
||||
} |
|
||||
|
|
||||
// deleteSessionFile deletes a session file
|
|
||||
func (f *FilerSessionStore) deleteSessionFile(ctx context.Context, client filer_pb.SeaweedFilerClient, fileName string) error { |
|
||||
request := &filer_pb.DeleteEntryRequest{ |
|
||||
Directory: f.basePath, |
|
||||
Name: fileName, |
|
||||
IsDeleteData: true, |
|
||||
IsRecursive: false, |
|
||||
IgnoreRecursiveError: false, |
|
||||
} |
|
||||
|
|
||||
_, err := client.DeleteEntry(ctx, request) |
|
||||
return err |
|
||||
} |
|
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue