diff --git a/weed/iam/ldap/ldap_provider.go b/weed/iam/ldap/ldap_provider.go new file mode 100644 index 000000000..76dd25d64 --- /dev/null +++ b/weed/iam/ldap/ldap_provider.go @@ -0,0 +1,219 @@ +package ldap + +import ( + "context" + "fmt" + "strings" + + "github.com/seaweedfs/seaweedfs/weed/iam/providers" +) + +// LDAPProvider implements LDAP authentication +type LDAPProvider struct { + name string + config *LDAPConfig + initialized bool + connPool interface{} // Will be proper LDAP connection pool +} + +// LDAPConfig holds LDAP provider configuration +type LDAPConfig struct { + // Server is the LDAP server URL (e.g., ldap://localhost:389) + Server string `json:"server"` + + // BaseDN is the base distinguished name for searches + BaseDN string `json:"baseDn"` + + // BindDN is the distinguished name for binding (authentication) + BindDN string `json:"bindDn,omitempty"` + + // BindPass is the password for binding + BindPass string `json:"bindPass,omitempty"` + + // UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)") + UserFilter string `json:"userFilter"` + + // GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)") + GroupFilter string `json:"groupFilter,omitempty"` + + // Attributes maps SeaweedFS identity fields to LDAP attributes + Attributes map[string]string `json:"attributes,omitempty"` + + // RoleMapping defines how to map LDAP groups to roles + RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` + + // TLS configuration + UseTLS bool `json:"useTls,omitempty"` + TLSCert string `json:"tlsCert,omitempty"` + TLSKey string `json:"tlsKey,omitempty"` + TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"` + + // Connection pool settings + MaxConnections int `json:"maxConnections,omitempty"` + ConnTimeout int `json:"connTimeout,omitempty"` // seconds +} + +// NewLDAPProvider creates a new LDAP provider +func NewLDAPProvider(name string) *LDAPProvider { + return &LDAPProvider{ + name: name, + } +} + +// Name returns the provider name +func (p *LDAPProvider) Name() string { + return p.name +} + +// Initialize initializes the LDAP provider with configuration +func (p *LDAPProvider) Initialize(config interface{}) error { + ldapConfig, ok := config.(*LDAPConfig) + if !ok { + return fmt.Errorf("invalid config type for LDAP provider") + } + + if err := p.validateConfig(ldapConfig); err != nil { + return fmt.Errorf("invalid LDAP configuration: %w", err) + } + + p.config = ldapConfig + p.initialized = true + + // TODO: Initialize LDAP connection pool + return fmt.Errorf("not implemented yet") +} + +// validateConfig validates the LDAP configuration +func (p *LDAPProvider) validateConfig(config *LDAPConfig) error { + if config.Server == "" { + return fmt.Errorf("server is required") + } + + if config.BaseDN == "" { + return fmt.Errorf("base DN is required") + } + + // Basic URL validation + if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") { + return fmt.Errorf("invalid server URL format") + } + + // Set default user filter if not provided + if config.UserFilter == "" { + config.UserFilter = "(uid=%s)" // Default LDAP user filter + } + + // Set default attributes if not provided + if config.Attributes == nil { + config.Attributes = map[string]string{ + "email": "mail", + "displayName": "cn", + "groups": "memberOf", + } + } + + return nil +} + +// Authenticate authenticates a user with LDAP +func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + // TODO: Parse credentials (username:password), bind to LDAP, search for user + return nil, fmt.Errorf("not implemented yet") +} + +// GetUserInfo retrieves user information from LDAP +func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + if userID == "" { + return nil, fmt.Errorf("user ID cannot be empty") + } + + // TODO: Search LDAP for user, get attributes + return nil, fmt.Errorf("not implemented yet") +} + +// ValidateToken validates credentials (for LDAP, this is username/password) +func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + // TODO: For LDAP, "token" would be username:password format + // Validate credentials and return claims + return nil, fmt.Errorf("not implemented yet") +} + +// mapLDAPAttributes maps LDAP attributes to ExternalIdentity +func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity { + identity := &providers.ExternalIdentity{ + UserID: userID, + Provider: p.name, + Attributes: make(map[string]string), + } + + // Map configured attributes + for identityField, ldapAttr := range p.config.Attributes { + if values, exists := attrs[ldapAttr]; exists && len(values) > 0 { + switch identityField { + case "email": + identity.Email = values[0] + case "displayName": + identity.DisplayName = values[0] + case "groups": + identity.Groups = values + default: + // Store as custom attribute + identity.Attributes[identityField] = values[0] + } + } + } + + return identity +} + +// mapUserToRole maps user groups to roles based on role mapping rules +func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string { + if p.config.RoleMapping == nil { + return "" + } + + // Create token claims from identity for rule matching + claims := &providers.TokenClaims{ + Subject: identity.UserID, + Claims: map[string]interface{}{ + "groups": identity.Groups, + "email": identity.Email, + }, + } + + // Check mapping rules + for _, rule := range p.config.RoleMapping.Rules { + if rule.Matches(claims) { + return rule.Role + } + } + + // Return default role if no rules match + return p.config.RoleMapping.DefaultRole +} + +// Connection management methods (stubs for now) +func (p *LDAPProvider) getConnectionPool() interface{} { + return p.connPool +} + +func (p *LDAPProvider) getConnection() (interface{}, error) { + // TODO: Get connection from pool + return nil, fmt.Errorf("not implemented") +} + +func (p *LDAPProvider) releaseConnection(conn interface{}) { + // TODO: Return connection to pool +} diff --git a/weed/iam/ldap/ldap_provider_test.go b/weed/iam/ldap/ldap_provider_test.go new file mode 100644 index 000000000..95caefa43 --- /dev/null +++ b/weed/iam/ldap/ldap_provider_test.go @@ -0,0 +1,357 @@ +package ldap + +import ( + "context" + "fmt" + "testing" + + "github.com/seaweedfs/seaweedfs/weed/iam/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestLDAPProviderInitialization tests LDAP provider initialization +func TestLDAPProviderInitialization(t *testing.T) { + tests := []struct { + name string + config *LDAPConfig + wantErr bool + }{ + { + name: "valid config", + config: &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + BindDN: "CN=admin,DC=example,DC=com", + BindPass: "password", + UserFilter: "(sAMAccountName=%s)", + GroupFilter: "(member=%s)", + }, + wantErr: false, + }, + { + name: "missing server", + config: &LDAPConfig{ + BaseDN: "DC=example,DC=com", + }, + wantErr: true, + }, + { + name: "missing base DN", + config: &LDAPConfig{ + Server: "ldap://localhost:389", + }, + wantErr: true, + }, + { + name: "invalid server URL", + config: &LDAPConfig{ + Server: "invalid-url", + BaseDN: "DC=example,DC=com", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := NewLDAPProvider("test-ldap") + + err := provider.Initialize(tt.config) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, "test-ldap", provider.Name()) + } + }) + } +} + +// TestLDAPProviderAuthentication tests LDAP authentication +func TestLDAPProviderAuthentication(t *testing.T) { + // Skip if no LDAP test server available + if testing.Short() { + t.Skip("Skipping LDAP integration test in short mode") + } + + provider := NewLDAPProvider("test-ldap") + config := &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + BindDN: "CN=admin,DC=example,DC=com", + BindPass: "password", + UserFilter: "(sAMAccountName=%s)", + GroupFilter: "(member=%s)", + Attributes: map[string]string{ + "email": "mail", + "displayName": "displayName", + "groups": "memberOf", + }, + RoleMapping: &providers.RoleMapping{ + Rules: []providers.MappingRule{ + { + Claim: "groups", + Value: "*CN=Admins*", + Role: "arn:seaweed:iam::role/AdminRole", + }, + { + Claim: "groups", + Value: "*CN=Users*", + Role: "arn:seaweed:iam::role/UserRole", + }, + }, + DefaultRole: "arn:seaweed:iam::role/GuestRole", + }, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("authenticate with username/password", func(t *testing.T) { + // This would require an actual LDAP server for integration testing + credentials := "user:password" // Basic auth format + + identity, err := provider.Authenticate(context.Background(), credentials) + if err != nil { + t.Skip("LDAP server not available for testing") + } + + assert.NoError(t, err) + assert.Equal(t, "user", identity.UserID) + assert.Equal(t, "test-ldap", identity.Provider) + assert.NotEmpty(t, identity.Email) + }) + + t.Run("authenticate with invalid credentials", func(t *testing.T) { + _, err := provider.Authenticate(context.Background(), "invalid:credentials") + assert.Error(t, err) + }) +} + +// TestLDAPProviderUserInfo tests LDAP user info retrieval +func TestLDAPProviderUserInfo(t *testing.T) { + if testing.Short() { + t.Skip("Skipping LDAP integration test in short mode") + } + + provider := NewLDAPProvider("test-ldap") + config := &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + BindDN: "CN=admin,DC=example,DC=com", + BindPass: "password", + UserFilter: "(sAMAccountName=%s)", + Attributes: map[string]string{ + "email": "mail", + "displayName": "displayName", + "department": "department", + }, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("get user info", func(t *testing.T) { + identity, err := provider.GetUserInfo(context.Background(), "testuser") + if err != nil { + t.Skip("LDAP server not available for testing") + } + + assert.NoError(t, err) + assert.Equal(t, "testuser", identity.UserID) + assert.Equal(t, "test-ldap", identity.Provider) + assert.NotEmpty(t, identity.Email) + assert.NotEmpty(t, identity.DisplayName) + }) + + t.Run("get user info with empty username", func(t *testing.T) { + _, err := provider.GetUserInfo(context.Background(), "") + assert.Error(t, err) + }) + + t.Run("get user info for non-existent user", func(t *testing.T) { + _, err := provider.GetUserInfo(context.Background(), "nonexistent") + assert.Error(t, err) + }) +} + +// TestLDAPAttributeMapping tests LDAP attribute mapping +func TestLDAPAttributeMapping(t *testing.T) { + provider := NewLDAPProvider("test-ldap") + config := &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + Attributes: map[string]string{ + "email": "mail", + "displayName": "cn", + "department": "departmentNumber", + "groups": "memberOf", + }, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("map LDAP attributes to identity", func(t *testing.T) { + ldapAttrs := map[string][]string{ + "mail": {"user@example.com"}, + "cn": {"John Doe"}, + "departmentNumber": {"IT"}, + "memberOf": { + "CN=Users,OU=Groups,DC=example,DC=com", + "CN=Developers,OU=Groups,DC=example,DC=com", + }, + } + + identity := provider.mapLDAPAttributes("testuser", ldapAttrs) + + assert.Equal(t, "testuser", identity.UserID) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "John Doe", identity.DisplayName) + assert.Equal(t, "test-ldap", identity.Provider) + + // Check groups + assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com") + assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com") + + // Check attributes + assert.Equal(t, "IT", identity.Attributes["department"]) + }) +} + +// TestLDAPGroupFiltering tests LDAP group filtering and role mapping +func TestLDAPGroupFiltering(t *testing.T) { + provider := NewLDAPProvider("test-ldap") + config := &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + RoleMapping: &providers.RoleMapping{ + Rules: []providers.MappingRule{ + { + Claim: "groups", + Value: "*Admins*", + Role: "arn:seaweed:iam::role/AdminRole", + }, + { + Claim: "groups", + Value: "*Users*", + Role: "arn:seaweed:iam::role/UserRole", + }, + }, + DefaultRole: "arn:seaweed:iam::role/GuestRole", + }, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + tests := []struct { + name string + groups []string + expectedRole string + expectedClaims map[string]interface{} + }{ + { + name: "admin user", + groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, + expectedRole: "arn:seaweed:iam::role/AdminRole", + }, + { + name: "regular user", + groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"}, + expectedRole: "arn:seaweed:iam::role/UserRole", + }, + { + name: "guest user", + groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"}, + expectedRole: "arn:seaweed:iam::role/GuestRole", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := &providers.ExternalIdentity{ + UserID: "testuser", + Groups: tt.groups, + Provider: "test-ldap", + } + + role := provider.mapUserToRole(identity) + assert.Equal(t, tt.expectedRole, role) + }) + } +} + +// TestLDAPConnectionPool tests LDAP connection pooling +func TestLDAPConnectionPool(t *testing.T) { + if testing.Short() { + t.Skip("Skipping LDAP connection pool test in short mode") + } + + provider := NewLDAPProvider("test-ldap") + config := &LDAPConfig{ + Server: "ldap://localhost:389", + BaseDN: "DC=example,DC=com", + BindDN: "CN=admin,DC=example,DC=com", + BindPass: "password", + MaxConnections: 5, + ConnTimeout: 30, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("connection pool management", func(t *testing.T) { + // Test that multiple concurrent requests work + // This would require actual LDAP server for full testing + pool := provider.getConnectionPool() + assert.NotNil(t, pool) + + // Test connection acquisition and release + conn, err := provider.getConnection() + if err != nil { + t.Skip("LDAP server not available") + } + + assert.NotNil(t, conn) + provider.releaseConnection(conn) + }) +} + +// MockLDAPServer for unit testing (without external dependencies) +type MockLDAPServer struct { + users map[string]map[string][]string +} + +func NewMockLDAPServer() *MockLDAPServer { + return &MockLDAPServer{ + users: map[string]map[string][]string{ + "testuser": { + "mail": {"testuser@example.com"}, + "cn": {"Test User"}, + "department": {"Engineering"}, + "memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"}, + }, + "admin": { + "mail": {"admin@example.com"}, + "cn": {"Administrator"}, + "department": {"IT"}, + "memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, + }, + }, + } +} + +func (m *MockLDAPServer) Authenticate(username, password string) bool { + _, exists := m.users[username] + return exists && password == "password" // Mock authentication +} + +func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) { + if attrs, exists := m.users[username]; exists { + return attrs, nil + } + return nil, fmt.Errorf("user not found: %s", username) +} diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go new file mode 100644 index 000000000..f9d895ab4 --- /dev/null +++ b/weed/iam/oidc/oidc_provider.go @@ -0,0 +1,124 @@ +package oidc + +import ( + "context" + "fmt" + + "github.com/seaweedfs/seaweedfs/weed/iam/providers" +) + +// OIDCProvider implements OpenID Connect authentication +type OIDCProvider struct { + name string + config *OIDCConfig + initialized bool +} + +// OIDCConfig holds OIDC provider configuration +type OIDCConfig struct { + // Issuer is the OIDC issuer URL + Issuer string `json:"issuer"` + + // ClientID is the OAuth2 client ID + ClientID string `json:"clientId"` + + // ClientSecret is the OAuth2 client secret (optional for public clients) + ClientSecret string `json:"clientSecret,omitempty"` + + // JWKSUri is the JSON Web Key Set URI + JWKSUri string `json:"jwksUri,omitempty"` + + // UserInfoUri is the UserInfo endpoint URI + UserInfoUri string `json:"userInfoUri,omitempty"` + + // Scopes are the OAuth2 scopes to request + Scopes []string `json:"scopes,omitempty"` + + // RoleMapping defines how to map OIDC claims to roles + RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` + + // ClaimsMapping defines how to map OIDC claims to identity attributes + ClaimsMapping map[string]string `json:"claimsMapping,omitempty"` +} + +// NewOIDCProvider creates a new OIDC provider +func NewOIDCProvider(name string) *OIDCProvider { + return &OIDCProvider{ + name: name, + } +} + +// Name returns the provider name +func (p *OIDCProvider) Name() string { + return p.name +} + +// Initialize initializes the OIDC provider with configuration +func (p *OIDCProvider) Initialize(config interface{}) error { + oidcConfig, ok := config.(*OIDCConfig) + if !ok { + return fmt.Errorf("invalid config type for OIDC provider") + } + + if err := p.validateConfig(oidcConfig); err != nil { + return fmt.Errorf("invalid OIDC configuration: %w", err) + } + + p.config = oidcConfig + p.initialized = true + + // TODO: Initialize OIDC client, fetch JWKS, etc. + return fmt.Errorf("not implemented yet") +} + +// validateConfig validates the OIDC configuration +func (p *OIDCProvider) validateConfig(config *OIDCConfig) error { + if config.Issuer == "" { + return fmt.Errorf("issuer is required") + } + + if config.ClientID == "" { + return fmt.Errorf("client ID is required") + } + + // Basic URL validation for issuer + if config.Issuer != "" && config.Issuer != "https://accounts.google.com" && config.Issuer[0:4] != "http" { + return fmt.Errorf("invalid issuer URL format") + } + + return nil +} + +// Authenticate authenticates a user with an OIDC token +func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + // TODO: Validate JWT token, extract claims, map to identity + return nil, fmt.Errorf("not implemented yet") +} + +// GetUserInfo retrieves user information from the UserInfo endpoint +func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + if userID == "" { + return nil, fmt.Errorf("user ID cannot be empty") + } + + // TODO: Call UserInfo endpoint + return nil, fmt.Errorf("not implemented yet") +} + +// ValidateToken validates an OIDC JWT token +func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + // TODO: Validate JWT signature, claims, expiration + return nil, fmt.Errorf("not implemented yet") +} diff --git a/weed/iam/oidc/oidc_provider_test.go b/weed/iam/oidc/oidc_provider_test.go new file mode 100644 index 000000000..c8441b810 --- /dev/null +++ b/weed/iam/oidc/oidc_provider_test.go @@ -0,0 +1,318 @@ +package oidc + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/seaweedfs/seaweedfs/weed/iam/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestOIDCProviderInitialization tests OIDC provider initialization +func TestOIDCProviderInitialization(t *testing.T) { + tests := []struct { + name string + config *OIDCConfig + wantErr bool + }{ + { + name: "valid config", + config: &OIDCConfig{ + Issuer: "https://accounts.google.com", + ClientID: "test-client-id", + JWKSUri: "https://www.googleapis.com/oauth2/v3/certs", + }, + wantErr: false, + }, + { + name: "missing issuer", + config: &OIDCConfig{ + ClientID: "test-client-id", + }, + wantErr: true, + }, + { + name: "missing client id", + config: &OIDCConfig{ + Issuer: "https://accounts.google.com", + }, + wantErr: true, + }, + { + name: "invalid issuer url", + config: &OIDCConfig{ + Issuer: "not-a-url", + ClientID: "test-client-id", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := NewOIDCProvider("test-provider") + + err := provider.Initialize(tt.config) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, "test-provider", provider.Name()) + } + }) + } +} + +// TestOIDCProviderJWTValidation tests JWT token validation +func TestOIDCProviderJWTValidation(t *testing.T) { + // Set up test server with JWKS endpoint + privateKey, publicKey := generateTestKeys(t) + + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-id", + "use": "sig", + "alg": "RS256", + "n": encodePublicKey(t, publicKey), + "e": "AQAB", + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid_configuration" { + config := map[string]interface{}{ + "issuer": "http://" + r.Host, + "jwks_uri": "http://" + r.Host + "/jwks", + } + json.NewEncoder(w).Encode(config) + } else if r.URL.Path == "/jwks" { + json.NewEncoder(w).Encode(jwks) + } + })) + defer server.Close() + + provider := NewOIDCProvider("test-oidc") + config := &OIDCConfig{ + Issuer: server.URL, + ClientID: "test-client", + JWKSUri: server.URL + "/jwks", + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("valid token", func(t *testing.T) { + // Create valid JWT token + token := createTestJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user123", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + "email": "user@example.com", + "name": "Test User", + }) + + claims, err := provider.ValidateToken(context.Background(), token) + assert.NoError(t, err) + assert.Equal(t, "user123", claims.Subject) + assert.Equal(t, server.URL, claims.Issuer) + + email, exists := claims.GetClaimString("email") + assert.True(t, exists) + assert.Equal(t, "user@example.com", email) + }) + + t.Run("expired token", func(t *testing.T) { + // Create expired JWT token + token := createTestJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user123", + "exp": time.Now().Add(-time.Hour).Unix(), // Expired + "iat": time.Now().Add(-time.Hour * 2).Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + assert.Error(t, err) + }) + + t.Run("invalid signature", func(t *testing.T) { + // Create token with wrong key + wrongKey, _ := generateTestKeys(t) + token := createTestJWT(t, wrongKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user123", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + _, err := provider.ValidateToken(context.Background(), token) + assert.Error(t, err) + }) +} + +// TestOIDCProviderAuthentication tests authentication flow +func TestOIDCProviderAuthentication(t *testing.T) { + // Set up test OIDC provider + privateKey, publicKey := generateTestKeys(t) + + server := setupOIDCTestServer(t, publicKey) + defer server.Close() + + provider := NewOIDCProvider("test-oidc") + config := &OIDCConfig{ + Issuer: server.URL, + ClientID: "test-client", + JWKSUri: server.URL + "/jwks", + RoleMapping: &providers.RoleMapping{ + Rules: []providers.MappingRule{ + { + Claim: "email", + Value: "*@example.com", + Role: "arn:seaweed:iam::role/UserRole", + }, + { + Claim: "groups", + Value: "admins", + Role: "arn:seaweed:iam::role/AdminRole", + }, + }, + DefaultRole: "arn:seaweed:iam::role/GuestRole", + }, + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("successful authentication", func(t *testing.T) { + token := createTestJWT(t, privateKey, jwt.MapClaims{ + "iss": server.URL, + "aud": "test-client", + "sub": "user123", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + "email": "user@example.com", + "name": "Test User", + "groups": []string{"users", "developers"}, + }) + + identity, err := provider.Authenticate(context.Background(), token) + assert.NoError(t, err) + assert.Equal(t, "user123", identity.UserID) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "Test User", identity.DisplayName) + assert.Equal(t, "test-oidc", identity.Provider) + assert.Contains(t, identity.Groups, "users") + assert.Contains(t, identity.Groups, "developers") + }) + + t.Run("authentication with invalid token", func(t *testing.T) { + _, err := provider.Authenticate(context.Background(), "invalid-token") + assert.Error(t, err) + }) +} + +// TestOIDCProviderUserInfo tests user info retrieval +func TestOIDCProviderUserInfo(t *testing.T) { + // Set up test server with UserInfo endpoint + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/userinfo" { + userInfo := map[string]interface{}{ + "sub": r.URL.Query().Get("user_id"), + "email": "user@example.com", + "name": "Test User", + } + json.NewEncoder(w).Encode(userInfo) + } + })) + defer server.Close() + + provider := NewOIDCProvider("test-oidc") + config := &OIDCConfig{ + Issuer: server.URL, + ClientID: "test-client", + UserInfoUri: server.URL + "/userinfo", + } + + err := provider.Initialize(config) + require.NoError(t, err) + + t.Run("get user info", func(t *testing.T) { + identity, err := provider.GetUserInfo(context.Background(), "user123") + assert.NoError(t, err) + assert.Equal(t, "user123", identity.UserID) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "Test User", identity.DisplayName) + }) + + t.Run("get user info with empty id", func(t *testing.T) { + _, err := provider.GetUserInfo(context.Background(), "") + assert.Error(t, err) + }) +} + +// Helper functions for testing + +func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return privateKey, &privateKey.PublicKey +} + +func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + return tokenString +} + +func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string { + // This is a simplified version - real implementation would properly encode the public key + return "test-public-key-n-value" +} + +func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server { + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-id", + "use": "sig", + "alg": "RS256", + "n": encodePublicKey(t, publicKey), + "e": "AQAB", + }, + }, + } + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid_configuration": + config := map[string]interface{}{ + "issuer": "http://" + r.Host, + "jwks_uri": "http://" + r.Host + "/jwks", + } + json.NewEncoder(w).Encode(config) + case "/jwks": + json.NewEncoder(w).Encode(jwks) + default: + http.NotFound(w, r) + } + })) +} diff --git a/weed/iam/providers/provider.go b/weed/iam/providers/provider.go new file mode 100644 index 000000000..2a3d9790f --- /dev/null +++ b/weed/iam/providers/provider.go @@ -0,0 +1,224 @@ +package providers + +import ( + "context" + "fmt" + "net/mail" + "strings" + "time" +) + +// IdentityProvider defines the interface for external identity providers +type IdentityProvider interface { + // Name returns the unique name of the provider + Name() string + + // Initialize initializes the provider with configuration + Initialize(config interface{}) error + + // Authenticate authenticates a user with a token and returns external identity + Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) + + // GetUserInfo retrieves user information by user ID + GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) + + // ValidateToken validates a token and returns claims + ValidateToken(ctx context.Context, token string) (*TokenClaims, error) +} + +// ExternalIdentity represents an identity from an external provider +type ExternalIdentity struct { + // UserID is the unique identifier from the external provider + UserID string `json:"userId"` + + // Email is the user's email address + Email string `json:"email"` + + // DisplayName is the user's display name + DisplayName string `json:"displayName"` + + // Groups are the groups the user belongs to + Groups []string `json:"groups,omitempty"` + + // Attributes are additional user attributes + Attributes map[string]string `json:"attributes,omitempty"` + + // Provider is the name of the identity provider + Provider string `json:"provider"` +} + +// Validate validates the external identity structure +func (e *ExternalIdentity) Validate() error { + if e.UserID == "" { + return fmt.Errorf("user ID is required") + } + + if e.Provider == "" { + return fmt.Errorf("provider is required") + } + + if e.Email != "" { + if _, err := mail.ParseAddress(e.Email); err != nil { + return fmt.Errorf("invalid email format: %w", err) + } + } + + return nil +} + +// TokenClaims represents claims from a validated token +type TokenClaims struct { + // Subject (sub) - user identifier + Subject string `json:"sub"` + + // Issuer (iss) - token issuer + Issuer string `json:"iss"` + + // Audience (aud) - intended audience + Audience string `json:"aud"` + + // ExpiresAt (exp) - expiration time + ExpiresAt time.Time `json:"exp"` + + // IssuedAt (iat) - issued at time + IssuedAt time.Time `json:"iat"` + + // NotBefore (nbf) - not valid before time + NotBefore time.Time `json:"nbf,omitempty"` + + // Claims are additional claims from the token + Claims map[string]interface{} `json:"claims,omitempty"` +} + +// IsValid checks if the token claims are valid (not expired, etc.) +func (c *TokenClaims) IsValid() bool { + now := time.Now() + + // Check expiration + if !c.ExpiresAt.IsZero() && now.After(c.ExpiresAt) { + return false + } + + // Check not before + if !c.NotBefore.IsZero() && now.Before(c.NotBefore) { + return false + } + + // Check issued at (shouldn't be in the future) + if !c.IssuedAt.IsZero() && now.Before(c.IssuedAt) { + return false + } + + return true +} + +// GetClaimString returns a string claim value +func (c *TokenClaims) GetClaimString(key string) (string, bool) { + if value, exists := c.Claims[key]; exists { + if str, ok := value.(string); ok { + return str, true + } + } + return "", false +} + +// GetClaimStringSlice returns a string slice claim value +func (c *TokenClaims) GetClaimStringSlice(key string) ([]string, bool) { + if value, exists := c.Claims[key]; exists { + switch v := value.(type) { + case []string: + return v, true + case []interface{}: + var result []string + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result, len(result) > 0 + case string: + // Single string can be treated as slice + return []string{v}, true + } + } + return nil, false +} + +// ProviderConfig represents configuration for identity providers +type ProviderConfig struct { + // Type of provider (oidc, ldap, saml) + Type string `json:"type"` + + // Name of the provider instance + Name string `json:"name"` + + // Enabled indicates if the provider is active + Enabled bool `json:"enabled"` + + // Config is provider-specific configuration + Config map[string]interface{} `json:"config"` + + // RoleMapping defines how to map external identities to roles + RoleMapping *RoleMapping `json:"roleMapping,omitempty"` +} + +// RoleMapping defines rules for mapping external identities to roles +type RoleMapping struct { + // Rules are the mapping rules + Rules []MappingRule `json:"rules"` + + // DefaultRole is assigned if no rules match + DefaultRole string `json:"defaultRole,omitempty"` +} + +// MappingRule defines a single mapping rule +type MappingRule struct { + // Claim is the claim key to check + Claim string `json:"claim"` + + // Value is the expected claim value (supports wildcards) + Value string `json:"value"` + + // Role is the role ARN to assign + Role string `json:"role"` + + // Condition is additional condition logic (optional) + Condition string `json:"condition,omitempty"` +} + +// Matches checks if a rule matches the given claims +func (r *MappingRule) Matches(claims *TokenClaims) bool { + if r.Claim == "" || r.Value == "" { + return false + } + + claimValue, exists := claims.GetClaimString(r.Claim) + if !exists { + // Try as string slice + if claimSlice, sliceExists := claims.GetClaimStringSlice(r.Claim); sliceExists { + for _, val := range claimSlice { + if r.matchValue(val) { + return true + } + } + } + return false + } + + return r.matchValue(claimValue) +} + +// matchValue checks if a value matches the rule value (with wildcard support) +func (r *MappingRule) matchValue(value string) bool { + // Simple wildcard matching + if strings.Contains(r.Value, "*") { + // Convert wildcard to regex-like matching + pattern := strings.ReplaceAll(r.Value, "*", "") + if pattern == "" { + return true // "*" matches everything + } + return strings.Contains(value, pattern) + } + + return value == r.Value +} diff --git a/weed/iam/providers/provider_test.go b/weed/iam/providers/provider_test.go new file mode 100644 index 000000000..99cf360c1 --- /dev/null +++ b/weed/iam/providers/provider_test.go @@ -0,0 +1,246 @@ +package providers + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestIdentityProviderInterface tests the core identity provider interface +func TestIdentityProviderInterface(t *testing.T) { + tests := []struct { + name string + provider IdentityProvider + wantErr bool + }{ + // We'll add test cases as we implement providers + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test provider name + name := tt.provider.Name() + assert.NotEmpty(t, name, "Provider name should not be empty") + + // Test initialization + err := tt.provider.Initialize(nil) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + // Test authentication with invalid token + ctx := context.Background() + _, err = tt.provider.Authenticate(ctx, "invalid-token") + assert.Error(t, err, "Should fail with invalid token") + }) + } +} + +// TestExternalIdentityValidation tests external identity structure validation +func TestExternalIdentityValidation(t *testing.T) { + tests := []struct { + name string + identity *ExternalIdentity + wantErr bool + }{ + { + name: "valid identity", + identity: &ExternalIdentity{ + UserID: "user123", + Email: "user@example.com", + DisplayName: "Test User", + Groups: []string{"group1", "group2"}, + Attributes: map[string]string{"dept": "engineering"}, + Provider: "test-provider", + }, + wantErr: false, + }, + { + name: "missing user id", + identity: &ExternalIdentity{ + Email: "user@example.com", + Provider: "test-provider", + }, + wantErr: true, + }, + { + name: "missing provider", + identity: &ExternalIdentity{ + UserID: "user123", + Email: "user@example.com", + }, + wantErr: true, + }, + { + name: "invalid email", + identity: &ExternalIdentity{ + UserID: "user123", + Email: "invalid-email", + Provider: "test-provider", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.identity.Validate() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestTokenClaimsValidation tests token claims structure +func TestTokenClaimsValidation(t *testing.T) { + tests := []struct { + name string + claims *TokenClaims + valid bool + }{ + { + name: "valid claims", + claims: &TokenClaims{ + Subject: "user123", + Issuer: "https://provider.example.com", + Audience: "seaweedfs", + ExpiresAt: time.Now().Add(time.Hour), + IssuedAt: time.Now().Add(-time.Minute), + Claims: map[string]interface{}{"email": "user@example.com"}, + }, + valid: true, + }, + { + name: "expired token", + claims: &TokenClaims{ + Subject: "user123", + Issuer: "https://provider.example.com", + Audience: "seaweedfs", + ExpiresAt: time.Now().Add(-time.Hour), // Expired + IssuedAt: time.Now().Add(-time.Hour * 2), + Claims: map[string]interface{}{"email": "user@example.com"}, + }, + valid: false, + }, + { + name: "future issued token", + claims: &TokenClaims{ + Subject: "user123", + Issuer: "https://provider.example.com", + Audience: "seaweedfs", + ExpiresAt: time.Now().Add(time.Hour), + IssuedAt: time.Now().Add(time.Hour), // Future + Claims: map[string]interface{}{"email": "user@example.com"}, + }, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + valid := tt.claims.IsValid() + assert.Equal(t, tt.valid, valid) + }) + } +} + +// TestProviderRegistry tests provider registration and discovery +func TestProviderRegistry(t *testing.T) { + // Clear registry for test + registry := NewProviderRegistry() + + t.Run("register provider", func(t *testing.T) { + mockProvider := &MockProvider{name: "test-provider"} + + err := registry.RegisterProvider(mockProvider) + assert.NoError(t, err) + + // Test duplicate registration + err = registry.RegisterProvider(mockProvider) + assert.Error(t, err, "Should not allow duplicate registration") + }) + + t.Run("get provider", func(t *testing.T) { + provider, exists := registry.GetProvider("test-provider") + assert.True(t, exists) + assert.Equal(t, "test-provider", provider.Name()) + + // Test non-existent provider + _, exists = registry.GetProvider("non-existent") + assert.False(t, exists) + }) + + t.Run("list providers", func(t *testing.T) { + providers := registry.ListProviders() + assert.Len(t, providers, 1) + assert.Equal(t, "test-provider", providers[0]) + }) +} + +// MockProvider for testing +type MockProvider struct { + name string + initialized bool + shouldError bool +} + +func (m *MockProvider) Name() string { + return m.name +} + +func (m *MockProvider) Initialize(config interface{}) error { + if m.shouldError { + return assert.AnError + } + m.initialized = true + return nil +} + +func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) { + if !m.initialized { + return nil, assert.AnError + } + if token == "invalid-token" { + return nil, assert.AnError + } + return &ExternalIdentity{ + UserID: "test-user", + Email: "test@example.com", + DisplayName: "Test User", + Provider: m.name, + }, nil +} + +func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) { + if !m.initialized || userID == "" { + return nil, assert.AnError + } + return &ExternalIdentity{ + UserID: userID, + Email: userID + "@example.com", + DisplayName: "User " + userID, + Provider: m.name, + }, nil +} + +func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) { + if !m.initialized || token == "invalid-token" { + return nil, assert.AnError + } + return &TokenClaims{ + Subject: "test-user", + Issuer: "test-issuer", + Audience: "seaweedfs", + ExpiresAt: time.Now().Add(time.Hour), + IssuedAt: time.Now(), + Claims: map[string]interface{}{"email": "test@example.com"}, + }, nil +} diff --git a/weed/iam/providers/registry.go b/weed/iam/providers/registry.go new file mode 100644 index 000000000..dee50df44 --- /dev/null +++ b/weed/iam/providers/registry.go @@ -0,0 +1,109 @@ +package providers + +import ( + "fmt" + "sync" +) + +// ProviderRegistry manages registered identity providers +type ProviderRegistry struct { + mu sync.RWMutex + providers map[string]IdentityProvider +} + +// NewProviderRegistry creates a new provider registry +func NewProviderRegistry() *ProviderRegistry { + return &ProviderRegistry{ + providers: make(map[string]IdentityProvider), + } +} + +// RegisterProvider registers a new identity provider +func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error { + if provider == nil { + return fmt.Errorf("provider cannot be nil") + } + + name := provider.Name() + if name == "" { + return fmt.Errorf("provider name cannot be empty") + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.providers[name]; exists { + return fmt.Errorf("provider %s is already registered", name) + } + + r.providers[name] = provider + return nil +} + +// GetProvider retrieves a provider by name +func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + provider, exists := r.providers[name] + return provider, exists +} + +// ListProviders returns all registered provider names +func (r *ProviderRegistry) ListProviders() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + var names []string + for name := range r.providers { + names = append(names, name) + } + return names +} + +// UnregisterProvider removes a provider from the registry +func (r *ProviderRegistry) UnregisterProvider(name string) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.providers[name]; !exists { + return fmt.Errorf("provider %s is not registered", name) + } + + delete(r.providers, name) + return nil +} + +// Clear removes all providers from the registry +func (r *ProviderRegistry) Clear() { + r.mu.Lock() + defer r.mu.Unlock() + + r.providers = make(map[string]IdentityProvider) +} + +// GetProviderCount returns the number of registered providers +func (r *ProviderRegistry) GetProviderCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.providers) +} + +// Default global registry +var defaultRegistry = NewProviderRegistry() + +// RegisterProvider registers a provider in the default registry +func RegisterProvider(provider IdentityProvider) error { + return defaultRegistry.RegisterProvider(provider) +} + +// GetProvider retrieves a provider from the default registry +func GetProvider(name string) (IdentityProvider, bool) { + return defaultRegistry.GetProvider(name) +} + +// ListProviders returns all provider names from the default registry +func ListProviders() []string { + return defaultRegistry.ListProviders() +}