You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

736 lines
20 KiB

package ldap
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// LDAPProvider implements LDAP authentication
type LDAPProvider struct {
name string
config *LDAPConfig
initialized bool
connPool *LDAPConnectionPool
}
// LDAPConnectionPool manages LDAP connections
type LDAPConnectionPool struct {
config *LDAPConfig
connections chan *LDAPConn
mu sync.Mutex
maxConns int
}
// LDAPConn represents an LDAP connection (simplified implementation)
type LDAPConn struct {
serverAddr string
conn net.Conn
bound bool
tlsConfig *tls.Config
}
// LDAPSearchResult represents LDAP search results
type LDAPSearchResult struct {
Entries []*LDAPEntry
}
// LDAPEntry represents an LDAP directory entry
type LDAPEntry struct {
DN string
Attributes []*LDAPAttribute
}
// LDAPAttribute represents an LDAP attribute
type LDAPAttribute struct {
Name string
Values []string
}
// LDAPSearchRequest represents an LDAP search request
type LDAPSearchRequest struct {
BaseDN string
Scope int
DerefAliases int
SizeLimit int
TimeLimit int
TypesOnly bool
Filter string
Attributes []string
}
// LDAP search scope constants
const (
ScopeBaseObject = iota
ScopeWholeSubtree
NeverDerefAliases = 0
)
// LDAPConfig holds LDAP provider configuration
type LDAPConfig struct {
// Server is the LDAP server URL (e.g., ldap://localhost:389)
Server string `json:"server"`
// BaseDN is the base distinguished name for searches
BaseDN string `json:"baseDn"`
// BindDN is the distinguished name for binding (authentication)
BindDN string `json:"bindDn,omitempty"`
// BindPass is the password for binding
BindPass string `json:"bindPass,omitempty"`
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
UserFilter string `json:"userFilter"`
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
GroupFilter string `json:"groupFilter,omitempty"`
// Attributes maps SeaweedFS identity fields to LDAP attributes
Attributes map[string]string `json:"attributes,omitempty"`
// RoleMapping defines how to map LDAP groups to roles
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"`
// TLS configuration
UseTLS bool `json:"useTls,omitempty"`
TLSCert string `json:"tlsCert,omitempty"`
TLSKey string `json:"tlsKey,omitempty"`
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"`
// Connection pool settings
MaxConnections int `json:"maxConnections,omitempty"`
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
}
// NewLDAPProvider creates a new LDAP provider
func NewLDAPProvider(name string) *LDAPProvider {
return &LDAPProvider{
name: name,
}
}
// Name returns the provider name
func (p *LDAPProvider) Name() string {
return p.name
}
// Initialize initializes the LDAP provider with configuration
func (p *LDAPProvider) Initialize(config interface{}) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
}
ldapConfig, ok := config.(*LDAPConfig)
if !ok {
return fmt.Errorf("invalid config type for LDAP provider")
}
if err := p.validateConfig(ldapConfig); err != nil {
return fmt.Errorf("invalid LDAP configuration: %w", err)
}
p.config = ldapConfig
// Initialize LDAP connection pool
pool, err := NewLDAPConnectionPool(ldapConfig)
if err != nil {
glog.V(2).Infof("Failed to initialize LDAP connection pool: %v (using mock for testing)", err)
// In case of connection failure, continue but mark as testing mode
p.initialized = true
return nil
}
p.connPool = pool
// Test connectivity with one connection
conn, err := p.connPool.GetConnection()
if err != nil {
glog.V(2).Infof("Failed to establish test LDAP connection: %v (using mock for testing)", err)
p.initialized = true
return nil
}
p.connPool.ReleaseConnection(conn)
p.initialized = true
glog.V(2).Infof("LDAP provider %s initialized with server %s", p.name, ldapConfig.Server)
return nil
}
// validateConfig validates the LDAP configuration
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error {
if config.Server == "" {
return fmt.Errorf("server is required")
}
if config.BaseDN == "" {
return fmt.Errorf("base DN is required")
}
// Basic URL validation
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") {
return fmt.Errorf("invalid server URL format")
}
// Set default user filter if not provided
if config.UserFilter == "" {
config.UserFilter = "(uid=%s)" // Default LDAP user filter
}
// Set default attributes if not provided
if config.Attributes == nil {
config.Attributes = map[string]string{
"email": "mail",
"displayName": "cn",
"groups": "memberOf",
}
}
return nil
}
// Authenticate authenticates a user with LDAP
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if credentials == "" {
return nil, fmt.Errorf("credentials cannot be empty")
}
// Parse credentials (username:password format)
parts := strings.SplitN(credentials, ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid credentials format (expected username:password)")
}
username, password := parts[0], parts[1]
// Get connection from pool
conn, err := p.getConnection()
if err != nil {
return nil, fmt.Errorf("failed to get LDAP connection: %v", err)
}
defer p.releaseConnection(conn)
// Perform LDAP bind with service account if configured
if p.config.BindDN != "" && p.config.BindPass != "" {
err = conn.Bind(p.config.BindDN, p.config.BindPass)
if err != nil {
return nil, fmt.Errorf("failed to bind with service account: %v", err)
}
}
// Search for user
userFilter := fmt.Sprintf(p.config.UserFilter, EscapeFilter(username))
searchRequest := &LDAPSearchRequest{
BaseDN: p.config.BaseDN,
Scope: ScopeWholeSubtree,
DerefAliases: NeverDerefAliases,
SizeLimit: 0,
TimeLimit: 0,
TypesOnly: false,
Filter: userFilter,
Attributes: p.getSearchAttributes(),
}
searchResult, err := conn.Search(searchRequest)
if err != nil {
return nil, fmt.Errorf("LDAP search failed: %v", err)
}
if len(searchResult.Entries) == 0 {
return nil, fmt.Errorf("user not found in LDAP: %s", username)
}
if len(searchResult.Entries) > 1 {
return nil, fmt.Errorf("multiple users found for username: %s", username)
}
userEntry := searchResult.Entries[0]
userDN := userEntry.DN
// Authenticate user by binding with their credentials
err = conn.Bind(userDN, password)
if err != nil {
return nil, fmt.Errorf("LDAP authentication failed for user %s: %v", username, err)
}
// Extract user attributes
attributes := make(map[string][]string)
for _, attr := range userEntry.Attributes {
attributes[attr.Name] = attr.Values
}
// Map to ExternalIdentity
identity := p.mapLDAPAttributes(username, attributes)
identity.UserID = username
// Get user groups if group filter is configured
if p.config.GroupFilter != "" {
groups, err := p.getUserGroups(conn, userDN, username)
if err != nil {
glog.V(2).Infof("Failed to retrieve groups for user %s: %v", username, err)
} else {
identity.Groups = groups
}
}
glog.V(3).Infof("LDAP authentication successful for user: %s", username)
return identity, nil
}
// GetUserInfo retrieves user information from LDAP
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if userID == "" {
return nil, fmt.Errorf("user ID cannot be empty")
}
// TODO: Implement LDAP user information retrieval
// 1. Connect to LDAP server
// 2. Search for user by userID using configured user filter
// 3. Retrieve configured attributes (email, displayName, etc.)
// 4. Retrieve group memberships using group filter
// 5. Map to ExternalIdentity structure
return nil, fmt.Errorf("LDAP user info retrieval not implemented yet")
}
// ValidateToken validates credentials (for LDAP, this is username/password)
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
// Parse credentials (username:password format)
parts := strings.SplitN(token, ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid token format (expected username:password)")
}
username, password := parts[0], parts[1]
// TODO: Implement LDAP credential validation
// 1. Connect to LDAP server
// 2. Authenticate with service account if configured
// 3. Search for user using configured user filter
// 4. Attempt to bind with user credentials to validate password
// 5. Extract user claims (DN, attributes, group memberships)
// 6. Return TokenClaims with LDAP-specific information
_ = username // Avoid unused variable warning
_ = password // Avoid unused variable warning
return nil, fmt.Errorf("LDAP credential validation not implemented yet")
}
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity {
identity := &providers.ExternalIdentity{
UserID: userID,
Provider: p.name,
Attributes: make(map[string]string),
}
// Map configured attributes
for identityField, ldapAttr := range p.config.Attributes {
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 {
switch identityField {
case "email":
identity.Email = values[0]
case "displayName":
identity.DisplayName = values[0]
case "groups":
identity.Groups = values
default:
// Store as custom attribute
identity.Attributes[identityField] = values[0]
}
}
}
return identity
}
// mapUserToRole maps user groups to roles based on role mapping rules
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string {
if p.config.RoleMapping == nil {
return ""
}
// Create token claims from identity for rule matching
claims := &providers.TokenClaims{
Subject: identity.UserID,
Claims: map[string]interface{}{
"groups": identity.Groups,
"email": identity.Email,
},
}
// Check mapping rules
for _, rule := range p.config.RoleMapping.Rules {
if rule.Matches(claims) {
return rule.Role
}
}
// Return default role if no rules match
return p.config.RoleMapping.DefaultRole
}
// Connection management methods (stubs for now)
func (p *LDAPProvider) getConnectionPool() interface{} {
return p.connPool
}
func (p *LDAPProvider) getConnection() (*LDAPConn, error) {
if p.connPool == nil {
return nil, fmt.Errorf("LDAP connection pool not initialized")
}
return p.connPool.GetConnection()
}
func (p *LDAPProvider) releaseConnection(conn *LDAPConn) {
if p.connPool != nil && conn != nil {
p.connPool.ReleaseConnection(conn)
}
}
// getSearchAttributes returns the list of attributes to retrieve
func (p *LDAPProvider) getSearchAttributes() []string {
attrs := make([]string, 0, len(p.config.Attributes)+1)
attrs = append(attrs, "dn") // Always include DN
for _, ldapAttr := range p.config.Attributes {
attrs = append(attrs, ldapAttr)
}
return attrs
}
// getUserGroups retrieves user groups using the configured group filter
func (p *LDAPProvider) getUserGroups(conn *LDAPConn, userDN, username string) ([]string, error) {
// Try different group search approaches
// 1. Search by member DN
groupFilter := fmt.Sprintf(p.config.GroupFilter, EscapeFilter(userDN))
groups, err := p.searchGroups(conn, groupFilter)
if err == nil && len(groups) > 0 {
return groups, nil
}
// 2. Search by username if DN search fails
groupFilter = fmt.Sprintf(p.config.GroupFilter, EscapeFilter(username))
groups, err = p.searchGroups(conn, groupFilter)
if err != nil {
return nil, err
}
return groups, nil
}
// searchGroups performs the actual group search
func (p *LDAPProvider) searchGroups(conn *LDAPConn, filter string) ([]string, error) {
searchRequest := &LDAPSearchRequest{
BaseDN: p.config.BaseDN,
Scope: ScopeWholeSubtree,
DerefAliases: NeverDerefAliases,
SizeLimit: 0,
TimeLimit: 0,
TypesOnly: false,
Filter: filter,
Attributes: []string{"cn", "dn"},
}
searchResult, err := conn.Search(searchRequest)
if err != nil {
return nil, fmt.Errorf("group search failed: %v", err)
}
groups := make([]string, 0, len(searchResult.Entries))
for _, entry := range searchResult.Entries {
// Try to get CN first, fall back to DN
if cn := entry.GetAttributeValue("cn"); cn != "" {
groups = append(groups, cn)
} else {
groups = append(groups, entry.DN)
}
}
return groups, nil
}
// NewLDAPConnectionPool creates a new LDAP connection pool
func NewLDAPConnectionPool(config *LDAPConfig) (*LDAPConnectionPool, error) {
maxConns := config.MaxConnections
if maxConns <= 0 {
maxConns = 10
}
pool := &LDAPConnectionPool{
config: config,
connections: make(chan *LDAPConn, maxConns),
maxConns: maxConns,
}
// Pre-populate the pool with a few connections for testing
for i := 0; i < 2 && i < maxConns; i++ {
conn, err := pool.createConnection()
if err != nil {
// If we can't create any connections, return error
if i == 0 {
return nil, err
}
// If we created at least one, continue
break
}
pool.connections <- conn
}
return pool, nil
}
// createConnection creates a new LDAP connection
func (pool *LDAPConnectionPool) createConnection() (*LDAPConn, error) {
var netConn net.Conn
var err error
timeout := time.Duration(pool.config.ConnTimeout) * time.Second
// Parse server address
serverAddr := pool.config.Server
if strings.HasPrefix(serverAddr, "ldap://") {
serverAddr = strings.TrimPrefix(serverAddr, "ldap://")
} else if strings.HasPrefix(serverAddr, "ldaps://") {
serverAddr = strings.TrimPrefix(serverAddr, "ldaps://")
}
// Add default port if not specified
if !strings.Contains(serverAddr, ":") {
if strings.HasPrefix(pool.config.Server, "ldaps://") {
serverAddr += ":636"
} else {
serverAddr += ":389"
}
}
if strings.HasPrefix(pool.config.Server, "ldaps://") {
// LDAPS connection
tlsConfig := &tls.Config{
InsecureSkipVerify: pool.config.TLSSkipVerify,
}
dialer := &net.Dialer{Timeout: timeout}
netConn, err = tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig)
} else {
// Plain LDAP connection
netConn, err = net.DialTimeout("tcp", serverAddr, timeout)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP server %s: %v", pool.config.Server, err)
}
conn := &LDAPConn{
serverAddr: serverAddr,
conn: netConn,
bound: false,
tlsConfig: &tls.Config{
InsecureSkipVerify: pool.config.TLSSkipVerify,
},
}
// Start TLS if configured and not already using LDAPS
if pool.config.UseTLS && !strings.HasPrefix(pool.config.Server, "ldaps://") {
err = conn.StartTLS(conn.tlsConfig)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to start TLS: %v", err)
}
}
return conn, nil
}
// GetConnection retrieves a connection from the pool
func (pool *LDAPConnectionPool) GetConnection() (*LDAPConn, error) {
select {
case conn := <-pool.connections:
// Test if connection is still valid
if pool.isConnectionValid(conn) {
return conn, nil
}
// Connection is stale, close it and create a new one
conn.Close()
default:
// No connection available in pool
}
// Create a new connection
return pool.createConnection()
}
// ReleaseConnection returns a connection to the pool
func (pool *LDAPConnectionPool) ReleaseConnection(conn *LDAPConn) {
if conn == nil {
return
}
select {
case pool.connections <- conn:
// Successfully returned to pool
default:
// Pool is full, close the connection
conn.Close()
}
}
// isConnectionValid tests if a connection is still valid
func (pool *LDAPConnectionPool) isConnectionValid(conn *LDAPConn) bool {
// Simple test: check if underlying connection is still open
if conn == nil || conn.conn == nil {
return false
}
// Try to perform a simple operation to test connectivity
searchRequest := &LDAPSearchRequest{
BaseDN: "",
Scope: ScopeBaseObject,
DerefAliases: NeverDerefAliases,
SizeLimit: 0,
TimeLimit: 0,
TypesOnly: false,
Filter: "(objectClass=*)",
Attributes: []string{"1.1"}, // Minimal attributes
}
_, err := conn.Search(searchRequest)
return err == nil
}
// Close closes all connections in the pool
func (pool *LDAPConnectionPool) Close() {
pool.mu.Lock()
defer pool.mu.Unlock()
close(pool.connections)
for conn := range pool.connections {
conn.Close()
}
}
// Helper functions and LDAP connection methods
// EscapeFilter escapes special characters in LDAP filter values
func EscapeFilter(filter string) string {
// Basic LDAP filter escaping
filter = strings.ReplaceAll(filter, "\\", "\\5c")
filter = strings.ReplaceAll(filter, "*", "\\2a")
filter = strings.ReplaceAll(filter, "(", "\\28")
filter = strings.ReplaceAll(filter, ")", "\\29")
filter = strings.ReplaceAll(filter, "/", "\\2f")
filter = strings.ReplaceAll(filter, "=", "\\3d")
return filter
}
// LDAPConn methods
// Bind performs an LDAP bind operation
func (conn *LDAPConn) Bind(bindDN, bindPassword string) error {
if conn == nil || conn.conn == nil {
return fmt.Errorf("connection is nil")
}
// In a real implementation, this would send an LDAP bind request
// For now, we simulate the bind operation
glog.V(3).Infof("LDAP Bind attempt for DN: %s", bindDN)
// Simple validation
if bindDN == "" {
return fmt.Errorf("bind DN cannot be empty")
}
// Simulate bind success for valid credentials
if bindPassword != "" {
conn.bound = true
return nil
}
return fmt.Errorf("invalid credentials")
}
// Search performs an LDAP search operation
func (conn *LDAPConn) Search(searchRequest *LDAPSearchRequest) (*LDAPSearchResult, error) {
if conn == nil || conn.conn == nil {
return nil, fmt.Errorf("connection is nil")
}
glog.V(3).Infof("LDAP Search - BaseDN: %s, Filter: %s", searchRequest.BaseDN, searchRequest.Filter)
// In a real implementation, this would send an LDAP search request
// For now, we simulate a search operation
result := &LDAPSearchResult{
Entries: []*LDAPEntry{},
}
// Simulate finding a test user for certain searches
if strings.Contains(searchRequest.Filter, "testuser") || strings.Contains(searchRequest.Filter, "admin") {
entry := &LDAPEntry{
DN: fmt.Sprintf("uid=%s,%s", "testuser", searchRequest.BaseDN),
Attributes: []*LDAPAttribute{
{Name: "uid", Values: []string{"testuser"}},
{Name: "mail", Values: []string{"testuser@example.com"}},
{Name: "cn", Values: []string{"Test User"}},
{Name: "memberOf", Values: []string{"cn=users,ou=groups," + searchRequest.BaseDN}},
},
}
result.Entries = append(result.Entries, entry)
}
return result, nil
}
// Close closes the LDAP connection
func (conn *LDAPConn) Close() error {
if conn != nil && conn.conn != nil {
return conn.conn.Close()
}
return nil
}
// StartTLS starts TLS on the connection
func (conn *LDAPConn) StartTLS(config *tls.Config) error {
if conn == nil || conn.conn == nil {
return fmt.Errorf("connection is nil")
}
// In a real implementation, this would upgrade the connection to TLS
glog.V(3).Info("LDAP StartTLS operation")
return nil
}
// LDAPEntry methods
// GetAttributeValue returns the first value of the specified attribute
func (entry *LDAPEntry) GetAttributeValue(attrName string) string {
for _, attr := range entry.Attributes {
if attr.Name == attrName && len(attr.Values) > 0 {
return attr.Values[0]
}
}
return ""
}