Browse Source
TDD RED Phase: Add identity provider framework tests
TDD RED Phase: Add identity provider framework tests
- Add core IdentityProvider interface with tests - Add OIDC provider tests with JWT token validation - Add LDAP provider tests with authentication flows - Add ProviderRegistry for managing multiple providers - Tests currently failing as expected in TDD RED phasepull/7160/head
7 changed files with 1597 additions and 0 deletions
-
219weed/iam/ldap/ldap_provider.go
-
357weed/iam/ldap/ldap_provider_test.go
-
124weed/iam/oidc/oidc_provider.go
-
318weed/iam/oidc/oidc_provider_test.go
-
224weed/iam/providers/provider.go
-
246weed/iam/providers/provider_test.go
-
109weed/iam/providers/registry.go
@ -0,0 +1,219 @@ |
|||||
|
package ldap |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
"strings" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
||||
|
) |
||||
|
|
||||
|
// LDAPProvider implements LDAP authentication
|
||||
|
type LDAPProvider struct { |
||||
|
name string |
||||
|
config *LDAPConfig |
||||
|
initialized bool |
||||
|
connPool interface{} // Will be proper LDAP connection pool
|
||||
|
} |
||||
|
|
||||
|
// LDAPConfig holds LDAP provider configuration
|
||||
|
type LDAPConfig struct { |
||||
|
// Server is the LDAP server URL (e.g., ldap://localhost:389)
|
||||
|
Server string `json:"server"` |
||||
|
|
||||
|
// BaseDN is the base distinguished name for searches
|
||||
|
BaseDN string `json:"baseDn"` |
||||
|
|
||||
|
// BindDN is the distinguished name for binding (authentication)
|
||||
|
BindDN string `json:"bindDn,omitempty"` |
||||
|
|
||||
|
// BindPass is the password for binding
|
||||
|
BindPass string `json:"bindPass,omitempty"` |
||||
|
|
||||
|
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
|
||||
|
UserFilter string `json:"userFilter"` |
||||
|
|
||||
|
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
|
||||
|
GroupFilter string `json:"groupFilter,omitempty"` |
||||
|
|
||||
|
// Attributes maps SeaweedFS identity fields to LDAP attributes
|
||||
|
Attributes map[string]string `json:"attributes,omitempty"` |
||||
|
|
||||
|
// RoleMapping defines how to map LDAP groups to roles
|
||||
|
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
||||
|
|
||||
|
// TLS configuration
|
||||
|
UseTLS bool `json:"useTls,omitempty"` |
||||
|
TLSCert string `json:"tlsCert,omitempty"` |
||||
|
TLSKey string `json:"tlsKey,omitempty"` |
||||
|
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"` |
||||
|
|
||||
|
// Connection pool settings
|
||||
|
MaxConnections int `json:"maxConnections,omitempty"` |
||||
|
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
|
||||
|
} |
||||
|
|
||||
|
// NewLDAPProvider creates a new LDAP provider
|
||||
|
func NewLDAPProvider(name string) *LDAPProvider { |
||||
|
return &LDAPProvider{ |
||||
|
name: name, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Name returns the provider name
|
||||
|
func (p *LDAPProvider) Name() string { |
||||
|
return p.name |
||||
|
} |
||||
|
|
||||
|
// Initialize initializes the LDAP provider with configuration
|
||||
|
func (p *LDAPProvider) Initialize(config interface{}) error { |
||||
|
ldapConfig, ok := config.(*LDAPConfig) |
||||
|
if !ok { |
||||
|
return fmt.Errorf("invalid config type for LDAP provider") |
||||
|
} |
||||
|
|
||||
|
if err := p.validateConfig(ldapConfig); err != nil { |
||||
|
return fmt.Errorf("invalid LDAP configuration: %w", err) |
||||
|
} |
||||
|
|
||||
|
p.config = ldapConfig |
||||
|
p.initialized = true |
||||
|
|
||||
|
// TODO: Initialize LDAP connection pool
|
||||
|
return fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// validateConfig validates the LDAP configuration
|
||||
|
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error { |
||||
|
if config.Server == "" { |
||||
|
return fmt.Errorf("server is required") |
||||
|
} |
||||
|
|
||||
|
if config.BaseDN == "" { |
||||
|
return fmt.Errorf("base DN is required") |
||||
|
} |
||||
|
|
||||
|
// Basic URL validation
|
||||
|
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") { |
||||
|
return fmt.Errorf("invalid server URL format") |
||||
|
} |
||||
|
|
||||
|
// Set default user filter if not provided
|
||||
|
if config.UserFilter == "" { |
||||
|
config.UserFilter = "(uid=%s)" // Default LDAP user filter
|
||||
|
} |
||||
|
|
||||
|
// Set default attributes if not provided
|
||||
|
if config.Attributes == nil { |
||||
|
config.Attributes = map[string]string{ |
||||
|
"email": "mail", |
||||
|
"displayName": "cn", |
||||
|
"groups": "memberOf", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Authenticate authenticates a user with LDAP
|
||||
|
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
// TODO: Parse credentials (username:password), bind to LDAP, search for user
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// GetUserInfo retrieves user information from LDAP
|
||||
|
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
if userID == "" { |
||||
|
return nil, fmt.Errorf("user ID cannot be empty") |
||||
|
} |
||||
|
|
||||
|
// TODO: Search LDAP for user, get attributes
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// ValidateToken validates credentials (for LDAP, this is username/password)
|
||||
|
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
// TODO: For LDAP, "token" would be username:password format
|
||||
|
// Validate credentials and return claims
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
|
||||
|
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity { |
||||
|
identity := &providers.ExternalIdentity{ |
||||
|
UserID: userID, |
||||
|
Provider: p.name, |
||||
|
Attributes: make(map[string]string), |
||||
|
} |
||||
|
|
||||
|
// Map configured attributes
|
||||
|
for identityField, ldapAttr := range p.config.Attributes { |
||||
|
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 { |
||||
|
switch identityField { |
||||
|
case "email": |
||||
|
identity.Email = values[0] |
||||
|
case "displayName": |
||||
|
identity.DisplayName = values[0] |
||||
|
case "groups": |
||||
|
identity.Groups = values |
||||
|
default: |
||||
|
// Store as custom attribute
|
||||
|
identity.Attributes[identityField] = values[0] |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return identity |
||||
|
} |
||||
|
|
||||
|
// mapUserToRole maps user groups to roles based on role mapping rules
|
||||
|
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string { |
||||
|
if p.config.RoleMapping == nil { |
||||
|
return "" |
||||
|
} |
||||
|
|
||||
|
// Create token claims from identity for rule matching
|
||||
|
claims := &providers.TokenClaims{ |
||||
|
Subject: identity.UserID, |
||||
|
Claims: map[string]interface{}{ |
||||
|
"groups": identity.Groups, |
||||
|
"email": identity.Email, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
// Check mapping rules
|
||||
|
for _, rule := range p.config.RoleMapping.Rules { |
||||
|
if rule.Matches(claims) { |
||||
|
return rule.Role |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Return default role if no rules match
|
||||
|
return p.config.RoleMapping.DefaultRole |
||||
|
} |
||||
|
|
||||
|
// Connection management methods (stubs for now)
|
||||
|
func (p *LDAPProvider) getConnectionPool() interface{} { |
||||
|
return p.connPool |
||||
|
} |
||||
|
|
||||
|
func (p *LDAPProvider) getConnection() (interface{}, error) { |
||||
|
// TODO: Get connection from pool
|
||||
|
return nil, fmt.Errorf("not implemented") |
||||
|
} |
||||
|
|
||||
|
func (p *LDAPProvider) releaseConnection(conn interface{}) { |
||||
|
// TODO: Return connection to pool
|
||||
|
} |
@ -0,0 +1,357 @@ |
|||||
|
package ldap |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
"testing" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
||||
|
"github.com/stretchr/testify/assert" |
||||
|
"github.com/stretchr/testify/require" |
||||
|
) |
||||
|
|
||||
|
// TestLDAPProviderInitialization tests LDAP provider initialization
|
||||
|
func TestLDAPProviderInitialization(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
config *LDAPConfig |
||||
|
wantErr bool |
||||
|
}{ |
||||
|
{ |
||||
|
name: "valid config", |
||||
|
config: &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
BindDN: "CN=admin,DC=example,DC=com", |
||||
|
BindPass: "password", |
||||
|
UserFilter: "(sAMAccountName=%s)", |
||||
|
GroupFilter: "(member=%s)", |
||||
|
}, |
||||
|
wantErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing server", |
||||
|
config: &LDAPConfig{ |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing base DN", |
||||
|
config: &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "invalid server URL", |
||||
|
config: &LDAPConfig{ |
||||
|
Server: "invalid-url", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
|
||||
|
err := provider.Initialize(tt.config) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
assert.Error(t, err) |
||||
|
} else { |
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "test-ldap", provider.Name()) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestLDAPProviderAuthentication tests LDAP authentication
|
||||
|
func TestLDAPProviderAuthentication(t *testing.T) { |
||||
|
// Skip if no LDAP test server available
|
||||
|
if testing.Short() { |
||||
|
t.Skip("Skipping LDAP integration test in short mode") |
||||
|
} |
||||
|
|
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
config := &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
BindDN: "CN=admin,DC=example,DC=com", |
||||
|
BindPass: "password", |
||||
|
UserFilter: "(sAMAccountName=%s)", |
||||
|
GroupFilter: "(member=%s)", |
||||
|
Attributes: map[string]string{ |
||||
|
"email": "mail", |
||||
|
"displayName": "displayName", |
||||
|
"groups": "memberOf", |
||||
|
}, |
||||
|
RoleMapping: &providers.RoleMapping{ |
||||
|
Rules: []providers.MappingRule{ |
||||
|
{ |
||||
|
Claim: "groups", |
||||
|
Value: "*CN=Admins*", |
||||
|
Role: "arn:seaweed:iam::role/AdminRole", |
||||
|
}, |
||||
|
{ |
||||
|
Claim: "groups", |
||||
|
Value: "*CN=Users*", |
||||
|
Role: "arn:seaweed:iam::role/UserRole", |
||||
|
}, |
||||
|
}, |
||||
|
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("authenticate with username/password", func(t *testing.T) { |
||||
|
// This would require an actual LDAP server for integration testing
|
||||
|
credentials := "user:password" // Basic auth format
|
||||
|
|
||||
|
identity, err := provider.Authenticate(context.Background(), credentials) |
||||
|
if err != nil { |
||||
|
t.Skip("LDAP server not available for testing") |
||||
|
} |
||||
|
|
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "user", identity.UserID) |
||||
|
assert.Equal(t, "test-ldap", identity.Provider) |
||||
|
assert.NotEmpty(t, identity.Email) |
||||
|
}) |
||||
|
|
||||
|
t.Run("authenticate with invalid credentials", func(t *testing.T) { |
||||
|
_, err := provider.Authenticate(context.Background(), "invalid:credentials") |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestLDAPProviderUserInfo tests LDAP user info retrieval
|
||||
|
func TestLDAPProviderUserInfo(t *testing.T) { |
||||
|
if testing.Short() { |
||||
|
t.Skip("Skipping LDAP integration test in short mode") |
||||
|
} |
||||
|
|
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
config := &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
BindDN: "CN=admin,DC=example,DC=com", |
||||
|
BindPass: "password", |
||||
|
UserFilter: "(sAMAccountName=%s)", |
||||
|
Attributes: map[string]string{ |
||||
|
"email": "mail", |
||||
|
"displayName": "displayName", |
||||
|
"department": "department", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("get user info", func(t *testing.T) { |
||||
|
identity, err := provider.GetUserInfo(context.Background(), "testuser") |
||||
|
if err != nil { |
||||
|
t.Skip("LDAP server not available for testing") |
||||
|
} |
||||
|
|
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "testuser", identity.UserID) |
||||
|
assert.Equal(t, "test-ldap", identity.Provider) |
||||
|
assert.NotEmpty(t, identity.Email) |
||||
|
assert.NotEmpty(t, identity.DisplayName) |
||||
|
}) |
||||
|
|
||||
|
t.Run("get user info with empty username", func(t *testing.T) { |
||||
|
_, err := provider.GetUserInfo(context.Background(), "") |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
|
||||
|
t.Run("get user info for non-existent user", func(t *testing.T) { |
||||
|
_, err := provider.GetUserInfo(context.Background(), "nonexistent") |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestLDAPAttributeMapping tests LDAP attribute mapping
|
||||
|
func TestLDAPAttributeMapping(t *testing.T) { |
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
config := &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
Attributes: map[string]string{ |
||||
|
"email": "mail", |
||||
|
"displayName": "cn", |
||||
|
"department": "departmentNumber", |
||||
|
"groups": "memberOf", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("map LDAP attributes to identity", func(t *testing.T) { |
||||
|
ldapAttrs := map[string][]string{ |
||||
|
"mail": {"user@example.com"}, |
||||
|
"cn": {"John Doe"}, |
||||
|
"departmentNumber": {"IT"}, |
||||
|
"memberOf": { |
||||
|
"CN=Users,OU=Groups,DC=example,DC=com", |
||||
|
"CN=Developers,OU=Groups,DC=example,DC=com", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
identity := provider.mapLDAPAttributes("testuser", ldapAttrs) |
||||
|
|
||||
|
assert.Equal(t, "testuser", identity.UserID) |
||||
|
assert.Equal(t, "user@example.com", identity.Email) |
||||
|
assert.Equal(t, "John Doe", identity.DisplayName) |
||||
|
assert.Equal(t, "test-ldap", identity.Provider) |
||||
|
|
||||
|
// Check groups
|
||||
|
assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com") |
||||
|
assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com") |
||||
|
|
||||
|
// Check attributes
|
||||
|
assert.Equal(t, "IT", identity.Attributes["department"]) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestLDAPGroupFiltering tests LDAP group filtering and role mapping
|
||||
|
func TestLDAPGroupFiltering(t *testing.T) { |
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
config := &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
RoleMapping: &providers.RoleMapping{ |
||||
|
Rules: []providers.MappingRule{ |
||||
|
{ |
||||
|
Claim: "groups", |
||||
|
Value: "*Admins*", |
||||
|
Role: "arn:seaweed:iam::role/AdminRole", |
||||
|
}, |
||||
|
{ |
||||
|
Claim: "groups", |
||||
|
Value: "*Users*", |
||||
|
Role: "arn:seaweed:iam::role/UserRole", |
||||
|
}, |
||||
|
}, |
||||
|
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
tests := []struct { |
||||
|
name string |
||||
|
groups []string |
||||
|
expectedRole string |
||||
|
expectedClaims map[string]interface{} |
||||
|
}{ |
||||
|
{ |
||||
|
name: "admin user", |
||||
|
groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
||||
|
expectedRole: "arn:seaweed:iam::role/AdminRole", |
||||
|
}, |
||||
|
{ |
||||
|
name: "regular user", |
||||
|
groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"}, |
||||
|
expectedRole: "arn:seaweed:iam::role/UserRole", |
||||
|
}, |
||||
|
{ |
||||
|
name: "guest user", |
||||
|
groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"}, |
||||
|
expectedRole: "arn:seaweed:iam::role/GuestRole", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
identity := &providers.ExternalIdentity{ |
||||
|
UserID: "testuser", |
||||
|
Groups: tt.groups, |
||||
|
Provider: "test-ldap", |
||||
|
} |
||||
|
|
||||
|
role := provider.mapUserToRole(identity) |
||||
|
assert.Equal(t, tt.expectedRole, role) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestLDAPConnectionPool tests LDAP connection pooling
|
||||
|
func TestLDAPConnectionPool(t *testing.T) { |
||||
|
if testing.Short() { |
||||
|
t.Skip("Skipping LDAP connection pool test in short mode") |
||||
|
} |
||||
|
|
||||
|
provider := NewLDAPProvider("test-ldap") |
||||
|
config := &LDAPConfig{ |
||||
|
Server: "ldap://localhost:389", |
||||
|
BaseDN: "DC=example,DC=com", |
||||
|
BindDN: "CN=admin,DC=example,DC=com", |
||||
|
BindPass: "password", |
||||
|
MaxConnections: 5, |
||||
|
ConnTimeout: 30, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("connection pool management", func(t *testing.T) { |
||||
|
// Test that multiple concurrent requests work
|
||||
|
// This would require actual LDAP server for full testing
|
||||
|
pool := provider.getConnectionPool() |
||||
|
assert.NotNil(t, pool) |
||||
|
|
||||
|
// Test connection acquisition and release
|
||||
|
conn, err := provider.getConnection() |
||||
|
if err != nil { |
||||
|
t.Skip("LDAP server not available") |
||||
|
} |
||||
|
|
||||
|
assert.NotNil(t, conn) |
||||
|
provider.releaseConnection(conn) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// MockLDAPServer for unit testing (without external dependencies)
|
||||
|
type MockLDAPServer struct { |
||||
|
users map[string]map[string][]string |
||||
|
} |
||||
|
|
||||
|
func NewMockLDAPServer() *MockLDAPServer { |
||||
|
return &MockLDAPServer{ |
||||
|
users: map[string]map[string][]string{ |
||||
|
"testuser": { |
||||
|
"mail": {"testuser@example.com"}, |
||||
|
"cn": {"Test User"}, |
||||
|
"department": {"Engineering"}, |
||||
|
"memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"}, |
||||
|
}, |
||||
|
"admin": { |
||||
|
"mail": {"admin@example.com"}, |
||||
|
"cn": {"Administrator"}, |
||||
|
"department": {"IT"}, |
||||
|
"memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (m *MockLDAPServer) Authenticate(username, password string) bool { |
||||
|
_, exists := m.users[username] |
||||
|
return exists && password == "password" // Mock authentication
|
||||
|
} |
||||
|
|
||||
|
func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) { |
||||
|
if attrs, exists := m.users[username]; exists { |
||||
|
return attrs, nil |
||||
|
} |
||||
|
return nil, fmt.Errorf("user not found: %s", username) |
||||
|
} |
@ -0,0 +1,124 @@ |
|||||
|
package oidc |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
||||
|
) |
||||
|
|
||||
|
// OIDCProvider implements OpenID Connect authentication
|
||||
|
type OIDCProvider struct { |
||||
|
name string |
||||
|
config *OIDCConfig |
||||
|
initialized bool |
||||
|
} |
||||
|
|
||||
|
// OIDCConfig holds OIDC provider configuration
|
||||
|
type OIDCConfig struct { |
||||
|
// Issuer is the OIDC issuer URL
|
||||
|
Issuer string `json:"issuer"` |
||||
|
|
||||
|
// ClientID is the OAuth2 client ID
|
||||
|
ClientID string `json:"clientId"` |
||||
|
|
||||
|
// ClientSecret is the OAuth2 client secret (optional for public clients)
|
||||
|
ClientSecret string `json:"clientSecret,omitempty"` |
||||
|
|
||||
|
// JWKSUri is the JSON Web Key Set URI
|
||||
|
JWKSUri string `json:"jwksUri,omitempty"` |
||||
|
|
||||
|
// UserInfoUri is the UserInfo endpoint URI
|
||||
|
UserInfoUri string `json:"userInfoUri,omitempty"` |
||||
|
|
||||
|
// Scopes are the OAuth2 scopes to request
|
||||
|
Scopes []string `json:"scopes,omitempty"` |
||||
|
|
||||
|
// RoleMapping defines how to map OIDC claims to roles
|
||||
|
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
||||
|
|
||||
|
// ClaimsMapping defines how to map OIDC claims to identity attributes
|
||||
|
ClaimsMapping map[string]string `json:"claimsMapping,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// NewOIDCProvider creates a new OIDC provider
|
||||
|
func NewOIDCProvider(name string) *OIDCProvider { |
||||
|
return &OIDCProvider{ |
||||
|
name: name, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Name returns the provider name
|
||||
|
func (p *OIDCProvider) Name() string { |
||||
|
return p.name |
||||
|
} |
||||
|
|
||||
|
// Initialize initializes the OIDC provider with configuration
|
||||
|
func (p *OIDCProvider) Initialize(config interface{}) error { |
||||
|
oidcConfig, ok := config.(*OIDCConfig) |
||||
|
if !ok { |
||||
|
return fmt.Errorf("invalid config type for OIDC provider") |
||||
|
} |
||||
|
|
||||
|
if err := p.validateConfig(oidcConfig); err != nil { |
||||
|
return fmt.Errorf("invalid OIDC configuration: %w", err) |
||||
|
} |
||||
|
|
||||
|
p.config = oidcConfig |
||||
|
p.initialized = true |
||||
|
|
||||
|
// TODO: Initialize OIDC client, fetch JWKS, etc.
|
||||
|
return fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// validateConfig validates the OIDC configuration
|
||||
|
func (p *OIDCProvider) validateConfig(config *OIDCConfig) error { |
||||
|
if config.Issuer == "" { |
||||
|
return fmt.Errorf("issuer is required") |
||||
|
} |
||||
|
|
||||
|
if config.ClientID == "" { |
||||
|
return fmt.Errorf("client ID is required") |
||||
|
} |
||||
|
|
||||
|
// Basic URL validation for issuer
|
||||
|
if config.Issuer != "" && config.Issuer != "https://accounts.google.com" && config.Issuer[0:4] != "http" { |
||||
|
return fmt.Errorf("invalid issuer URL format") |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Authenticate authenticates a user with an OIDC token
|
||||
|
func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
// TODO: Validate JWT token, extract claims, map to identity
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// GetUserInfo retrieves user information from the UserInfo endpoint
|
||||
|
func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
if userID == "" { |
||||
|
return nil, fmt.Errorf("user ID cannot be empty") |
||||
|
} |
||||
|
|
||||
|
// TODO: Call UserInfo endpoint
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
||||
|
|
||||
|
// ValidateToken validates an OIDC JWT token
|
||||
|
func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
||||
|
if !p.initialized { |
||||
|
return nil, fmt.Errorf("provider not initialized") |
||||
|
} |
||||
|
|
||||
|
// TODO: Validate JWT signature, claims, expiration
|
||||
|
return nil, fmt.Errorf("not implemented yet") |
||||
|
} |
@ -0,0 +1,318 @@ |
|||||
|
package oidc |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/golang-jwt/jwt/v5" |
||||
|
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
||||
|
"github.com/stretchr/testify/assert" |
||||
|
"github.com/stretchr/testify/require" |
||||
|
) |
||||
|
|
||||
|
// TestOIDCProviderInitialization tests OIDC provider initialization
|
||||
|
func TestOIDCProviderInitialization(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
config *OIDCConfig |
||||
|
wantErr bool |
||||
|
}{ |
||||
|
{ |
||||
|
name: "valid config", |
||||
|
config: &OIDCConfig{ |
||||
|
Issuer: "https://accounts.google.com", |
||||
|
ClientID: "test-client-id", |
||||
|
JWKSUri: "https://www.googleapis.com/oauth2/v3/certs", |
||||
|
}, |
||||
|
wantErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing issuer", |
||||
|
config: &OIDCConfig{ |
||||
|
ClientID: "test-client-id", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing client id", |
||||
|
config: &OIDCConfig{ |
||||
|
Issuer: "https://accounts.google.com", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "invalid issuer url", |
||||
|
config: &OIDCConfig{ |
||||
|
Issuer: "not-a-url", |
||||
|
ClientID: "test-client-id", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
provider := NewOIDCProvider("test-provider") |
||||
|
|
||||
|
err := provider.Initialize(tt.config) |
||||
|
|
||||
|
if tt.wantErr { |
||||
|
assert.Error(t, err) |
||||
|
} else { |
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "test-provider", provider.Name()) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestOIDCProviderJWTValidation tests JWT token validation
|
||||
|
func TestOIDCProviderJWTValidation(t *testing.T) { |
||||
|
// Set up test server with JWKS endpoint
|
||||
|
privateKey, publicKey := generateTestKeys(t) |
||||
|
|
||||
|
jwks := map[string]interface{}{ |
||||
|
"keys": []map[string]interface{}{ |
||||
|
{ |
||||
|
"kty": "RSA", |
||||
|
"kid": "test-key-id", |
||||
|
"use": "sig", |
||||
|
"alg": "RS256", |
||||
|
"n": encodePublicKey(t, publicKey), |
||||
|
"e": "AQAB", |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.Path == "/.well-known/openid_configuration" { |
||||
|
config := map[string]interface{}{ |
||||
|
"issuer": "http://" + r.Host, |
||||
|
"jwks_uri": "http://" + r.Host + "/jwks", |
||||
|
} |
||||
|
json.NewEncoder(w).Encode(config) |
||||
|
} else if r.URL.Path == "/jwks" { |
||||
|
json.NewEncoder(w).Encode(jwks) |
||||
|
} |
||||
|
})) |
||||
|
defer server.Close() |
||||
|
|
||||
|
provider := NewOIDCProvider("test-oidc") |
||||
|
config := &OIDCConfig{ |
||||
|
Issuer: server.URL, |
||||
|
ClientID: "test-client", |
||||
|
JWKSUri: server.URL + "/jwks", |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("valid token", func(t *testing.T) { |
||||
|
// Create valid JWT token
|
||||
|
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
||||
|
"iss": server.URL, |
||||
|
"aud": "test-client", |
||||
|
"sub": "user123", |
||||
|
"exp": time.Now().Add(time.Hour).Unix(), |
||||
|
"iat": time.Now().Unix(), |
||||
|
"email": "user@example.com", |
||||
|
"name": "Test User", |
||||
|
}) |
||||
|
|
||||
|
claims, err := provider.ValidateToken(context.Background(), token) |
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "user123", claims.Subject) |
||||
|
assert.Equal(t, server.URL, claims.Issuer) |
||||
|
|
||||
|
email, exists := claims.GetClaimString("email") |
||||
|
assert.True(t, exists) |
||||
|
assert.Equal(t, "user@example.com", email) |
||||
|
}) |
||||
|
|
||||
|
t.Run("expired token", func(t *testing.T) { |
||||
|
// Create expired JWT token
|
||||
|
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
||||
|
"iss": server.URL, |
||||
|
"aud": "test-client", |
||||
|
"sub": "user123", |
||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
|
"iat": time.Now().Add(-time.Hour * 2).Unix(), |
||||
|
}) |
||||
|
|
||||
|
_, err := provider.ValidateToken(context.Background(), token) |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
|
||||
|
t.Run("invalid signature", func(t *testing.T) { |
||||
|
// Create token with wrong key
|
||||
|
wrongKey, _ := generateTestKeys(t) |
||||
|
token := createTestJWT(t, wrongKey, jwt.MapClaims{ |
||||
|
"iss": server.URL, |
||||
|
"aud": "test-client", |
||||
|
"sub": "user123", |
||||
|
"exp": time.Now().Add(time.Hour).Unix(), |
||||
|
"iat": time.Now().Unix(), |
||||
|
}) |
||||
|
|
||||
|
_, err := provider.ValidateToken(context.Background(), token) |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestOIDCProviderAuthentication tests authentication flow
|
||||
|
func TestOIDCProviderAuthentication(t *testing.T) { |
||||
|
// Set up test OIDC provider
|
||||
|
privateKey, publicKey := generateTestKeys(t) |
||||
|
|
||||
|
server := setupOIDCTestServer(t, publicKey) |
||||
|
defer server.Close() |
||||
|
|
||||
|
provider := NewOIDCProvider("test-oidc") |
||||
|
config := &OIDCConfig{ |
||||
|
Issuer: server.URL, |
||||
|
ClientID: "test-client", |
||||
|
JWKSUri: server.URL + "/jwks", |
||||
|
RoleMapping: &providers.RoleMapping{ |
||||
|
Rules: []providers.MappingRule{ |
||||
|
{ |
||||
|
Claim: "email", |
||||
|
Value: "*@example.com", |
||||
|
Role: "arn:seaweed:iam::role/UserRole", |
||||
|
}, |
||||
|
{ |
||||
|
Claim: "groups", |
||||
|
Value: "admins", |
||||
|
Role: "arn:seaweed:iam::role/AdminRole", |
||||
|
}, |
||||
|
}, |
||||
|
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("successful authentication", func(t *testing.T) { |
||||
|
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
||||
|
"iss": server.URL, |
||||
|
"aud": "test-client", |
||||
|
"sub": "user123", |
||||
|
"exp": time.Now().Add(time.Hour).Unix(), |
||||
|
"iat": time.Now().Unix(), |
||||
|
"email": "user@example.com", |
||||
|
"name": "Test User", |
||||
|
"groups": []string{"users", "developers"}, |
||||
|
}) |
||||
|
|
||||
|
identity, err := provider.Authenticate(context.Background(), token) |
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "user123", identity.UserID) |
||||
|
assert.Equal(t, "user@example.com", identity.Email) |
||||
|
assert.Equal(t, "Test User", identity.DisplayName) |
||||
|
assert.Equal(t, "test-oidc", identity.Provider) |
||||
|
assert.Contains(t, identity.Groups, "users") |
||||
|
assert.Contains(t, identity.Groups, "developers") |
||||
|
}) |
||||
|
|
||||
|
t.Run("authentication with invalid token", func(t *testing.T) { |
||||
|
_, err := provider.Authenticate(context.Background(), "invalid-token") |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// TestOIDCProviderUserInfo tests user info retrieval
|
||||
|
func TestOIDCProviderUserInfo(t *testing.T) { |
||||
|
// Set up test server with UserInfo endpoint
|
||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.Path == "/userinfo" { |
||||
|
userInfo := map[string]interface{}{ |
||||
|
"sub": r.URL.Query().Get("user_id"), |
||||
|
"email": "user@example.com", |
||||
|
"name": "Test User", |
||||
|
} |
||||
|
json.NewEncoder(w).Encode(userInfo) |
||||
|
} |
||||
|
})) |
||||
|
defer server.Close() |
||||
|
|
||||
|
provider := NewOIDCProvider("test-oidc") |
||||
|
config := &OIDCConfig{ |
||||
|
Issuer: server.URL, |
||||
|
ClientID: "test-client", |
||||
|
UserInfoUri: server.URL + "/userinfo", |
||||
|
} |
||||
|
|
||||
|
err := provider.Initialize(config) |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
t.Run("get user info", func(t *testing.T) { |
||||
|
identity, err := provider.GetUserInfo(context.Background(), "user123") |
||||
|
assert.NoError(t, err) |
||||
|
assert.Equal(t, "user123", identity.UserID) |
||||
|
assert.Equal(t, "user@example.com", identity.Email) |
||||
|
assert.Equal(t, "Test User", identity.DisplayName) |
||||
|
}) |
||||
|
|
||||
|
t.Run("get user info with empty id", func(t *testing.T) { |
||||
|
_, err := provider.GetUserInfo(context.Background(), "") |
||||
|
assert.Error(t, err) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// Helper functions for testing
|
||||
|
|
||||
|
func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { |
||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
|
require.NoError(t, err) |
||||
|
return privateKey, &privateKey.PublicKey |
||||
|
} |
||||
|
|
||||
|
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string { |
||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) |
||||
|
token.Header["kid"] = "test-key-id" |
||||
|
|
||||
|
tokenString, err := token.SignedString(privateKey) |
||||
|
require.NoError(t, err) |
||||
|
return tokenString |
||||
|
} |
||||
|
|
||||
|
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string { |
||||
|
// This is a simplified version - real implementation would properly encode the public key
|
||||
|
return "test-public-key-n-value" |
||||
|
} |
||||
|
|
||||
|
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server { |
||||
|
jwks := map[string]interface{}{ |
||||
|
"keys": []map[string]interface{}{ |
||||
|
{ |
||||
|
"kty": "RSA", |
||||
|
"kid": "test-key-id", |
||||
|
"use": "sig", |
||||
|
"alg": "RS256", |
||||
|
"n": encodePublicKey(t, publicKey), |
||||
|
"e": "AQAB", |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
switch r.URL.Path { |
||||
|
case "/.well-known/openid_configuration": |
||||
|
config := map[string]interface{}{ |
||||
|
"issuer": "http://" + r.Host, |
||||
|
"jwks_uri": "http://" + r.Host + "/jwks", |
||||
|
} |
||||
|
json.NewEncoder(w).Encode(config) |
||||
|
case "/jwks": |
||||
|
json.NewEncoder(w).Encode(jwks) |
||||
|
default: |
||||
|
http.NotFound(w, r) |
||||
|
} |
||||
|
})) |
||||
|
} |
@ -0,0 +1,224 @@ |
|||||
|
package providers |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
"net/mail" |
||||
|
"strings" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// IdentityProvider defines the interface for external identity providers
|
||||
|
type IdentityProvider interface { |
||||
|
// Name returns the unique name of the provider
|
||||
|
Name() string |
||||
|
|
||||
|
// Initialize initializes the provider with configuration
|
||||
|
Initialize(config interface{}) error |
||||
|
|
||||
|
// Authenticate authenticates a user with a token and returns external identity
|
||||
|
Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) |
||||
|
|
||||
|
// GetUserInfo retrieves user information by user ID
|
||||
|
GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) |
||||
|
|
||||
|
// ValidateToken validates a token and returns claims
|
||||
|
ValidateToken(ctx context.Context, token string) (*TokenClaims, error) |
||||
|
} |
||||
|
|
||||
|
// ExternalIdentity represents an identity from an external provider
|
||||
|
type ExternalIdentity struct { |
||||
|
// UserID is the unique identifier from the external provider
|
||||
|
UserID string `json:"userId"` |
||||
|
|
||||
|
// Email is the user's email address
|
||||
|
Email string `json:"email"` |
||||
|
|
||||
|
// DisplayName is the user's display name
|
||||
|
DisplayName string `json:"displayName"` |
||||
|
|
||||
|
// Groups are the groups the user belongs to
|
||||
|
Groups []string `json:"groups,omitempty"` |
||||
|
|
||||
|
// Attributes are additional user attributes
|
||||
|
Attributes map[string]string `json:"attributes,omitempty"` |
||||
|
|
||||
|
// Provider is the name of the identity provider
|
||||
|
Provider string `json:"provider"` |
||||
|
} |
||||
|
|
||||
|
// Validate validates the external identity structure
|
||||
|
func (e *ExternalIdentity) Validate() error { |
||||
|
if e.UserID == "" { |
||||
|
return fmt.Errorf("user ID is required") |
||||
|
} |
||||
|
|
||||
|
if e.Provider == "" { |
||||
|
return fmt.Errorf("provider is required") |
||||
|
} |
||||
|
|
||||
|
if e.Email != "" { |
||||
|
if _, err := mail.ParseAddress(e.Email); err != nil { |
||||
|
return fmt.Errorf("invalid email format: %w", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// TokenClaims represents claims from a validated token
|
||||
|
type TokenClaims struct { |
||||
|
// Subject (sub) - user identifier
|
||||
|
Subject string `json:"sub"` |
||||
|
|
||||
|
// Issuer (iss) - token issuer
|
||||
|
Issuer string `json:"iss"` |
||||
|
|
||||
|
// Audience (aud) - intended audience
|
||||
|
Audience string `json:"aud"` |
||||
|
|
||||
|
// ExpiresAt (exp) - expiration time
|
||||
|
ExpiresAt time.Time `json:"exp"` |
||||
|
|
||||
|
// IssuedAt (iat) - issued at time
|
||||
|
IssuedAt time.Time `json:"iat"` |
||||
|
|
||||
|
// NotBefore (nbf) - not valid before time
|
||||
|
NotBefore time.Time `json:"nbf,omitempty"` |
||||
|
|
||||
|
// Claims are additional claims from the token
|
||||
|
Claims map[string]interface{} `json:"claims,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// IsValid checks if the token claims are valid (not expired, etc.)
|
||||
|
func (c *TokenClaims) IsValid() bool { |
||||
|
now := time.Now() |
||||
|
|
||||
|
// Check expiration
|
||||
|
if !c.ExpiresAt.IsZero() && now.After(c.ExpiresAt) { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
// Check not before
|
||||
|
if !c.NotBefore.IsZero() && now.Before(c.NotBefore) { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
// Check issued at (shouldn't be in the future)
|
||||
|
if !c.IssuedAt.IsZero() && now.Before(c.IssuedAt) { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
// GetClaimString returns a string claim value
|
||||
|
func (c *TokenClaims) GetClaimString(key string) (string, bool) { |
||||
|
if value, exists := c.Claims[key]; exists { |
||||
|
if str, ok := value.(string); ok { |
||||
|
return str, true |
||||
|
} |
||||
|
} |
||||
|
return "", false |
||||
|
} |
||||
|
|
||||
|
// GetClaimStringSlice returns a string slice claim value
|
||||
|
func (c *TokenClaims) GetClaimStringSlice(key string) ([]string, bool) { |
||||
|
if value, exists := c.Claims[key]; exists { |
||||
|
switch v := value.(type) { |
||||
|
case []string: |
||||
|
return v, true |
||||
|
case []interface{}: |
||||
|
var result []string |
||||
|
for _, item := range v { |
||||
|
if str, ok := item.(string); ok { |
||||
|
result = append(result, str) |
||||
|
} |
||||
|
} |
||||
|
return result, len(result) > 0 |
||||
|
case string: |
||||
|
// Single string can be treated as slice
|
||||
|
return []string{v}, true |
||||
|
} |
||||
|
} |
||||
|
return nil, false |
||||
|
} |
||||
|
|
||||
|
// ProviderConfig represents configuration for identity providers
|
||||
|
type ProviderConfig struct { |
||||
|
// Type of provider (oidc, ldap, saml)
|
||||
|
Type string `json:"type"` |
||||
|
|
||||
|
// Name of the provider instance
|
||||
|
Name string `json:"name"` |
||||
|
|
||||
|
// Enabled indicates if the provider is active
|
||||
|
Enabled bool `json:"enabled"` |
||||
|
|
||||
|
// Config is provider-specific configuration
|
||||
|
Config map[string]interface{} `json:"config"` |
||||
|
|
||||
|
// RoleMapping defines how to map external identities to roles
|
||||
|
RoleMapping *RoleMapping `json:"roleMapping,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// RoleMapping defines rules for mapping external identities to roles
|
||||
|
type RoleMapping struct { |
||||
|
// Rules are the mapping rules
|
||||
|
Rules []MappingRule `json:"rules"` |
||||
|
|
||||
|
// DefaultRole is assigned if no rules match
|
||||
|
DefaultRole string `json:"defaultRole,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// MappingRule defines a single mapping rule
|
||||
|
type MappingRule struct { |
||||
|
// Claim is the claim key to check
|
||||
|
Claim string `json:"claim"` |
||||
|
|
||||
|
// Value is the expected claim value (supports wildcards)
|
||||
|
Value string `json:"value"` |
||||
|
|
||||
|
// Role is the role ARN to assign
|
||||
|
Role string `json:"role"` |
||||
|
|
||||
|
// Condition is additional condition logic (optional)
|
||||
|
Condition string `json:"condition,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// Matches checks if a rule matches the given claims
|
||||
|
func (r *MappingRule) Matches(claims *TokenClaims) bool { |
||||
|
if r.Claim == "" || r.Value == "" { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
claimValue, exists := claims.GetClaimString(r.Claim) |
||||
|
if !exists { |
||||
|
// Try as string slice
|
||||
|
if claimSlice, sliceExists := claims.GetClaimStringSlice(r.Claim); sliceExists { |
||||
|
for _, val := range claimSlice { |
||||
|
if r.matchValue(val) { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
return r.matchValue(claimValue) |
||||
|
} |
||||
|
|
||||
|
// matchValue checks if a value matches the rule value (with wildcard support)
|
||||
|
func (r *MappingRule) matchValue(value string) bool { |
||||
|
// Simple wildcard matching
|
||||
|
if strings.Contains(r.Value, "*") { |
||||
|
// Convert wildcard to regex-like matching
|
||||
|
pattern := strings.ReplaceAll(r.Value, "*", "") |
||||
|
if pattern == "" { |
||||
|
return true // "*" matches everything
|
||||
|
} |
||||
|
return strings.Contains(value, pattern) |
||||
|
} |
||||
|
|
||||
|
return value == r.Value |
||||
|
} |
@ -0,0 +1,246 @@ |
|||||
|
package providers |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/stretchr/testify/assert" |
||||
|
"github.com/stretchr/testify/require" |
||||
|
) |
||||
|
|
||||
|
// TestIdentityProviderInterface tests the core identity provider interface
|
||||
|
func TestIdentityProviderInterface(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
provider IdentityProvider |
||||
|
wantErr bool |
||||
|
}{ |
||||
|
// We'll add test cases as we implement providers
|
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
// Test provider name
|
||||
|
name := tt.provider.Name() |
||||
|
assert.NotEmpty(t, name, "Provider name should not be empty") |
||||
|
|
||||
|
// Test initialization
|
||||
|
err := tt.provider.Initialize(nil) |
||||
|
if tt.wantErr { |
||||
|
assert.Error(t, err) |
||||
|
return |
||||
|
} |
||||
|
require.NoError(t, err) |
||||
|
|
||||
|
// Test authentication with invalid token
|
||||
|
ctx := context.Background() |
||||
|
_, err = tt.provider.Authenticate(ctx, "invalid-token") |
||||
|
assert.Error(t, err, "Should fail with invalid token") |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestExternalIdentityValidation tests external identity structure validation
|
||||
|
func TestExternalIdentityValidation(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
identity *ExternalIdentity |
||||
|
wantErr bool |
||||
|
}{ |
||||
|
{ |
||||
|
name: "valid identity", |
||||
|
identity: &ExternalIdentity{ |
||||
|
UserID: "user123", |
||||
|
Email: "user@example.com", |
||||
|
DisplayName: "Test User", |
||||
|
Groups: []string{"group1", "group2"}, |
||||
|
Attributes: map[string]string{"dept": "engineering"}, |
||||
|
Provider: "test-provider", |
||||
|
}, |
||||
|
wantErr: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing user id", |
||||
|
identity: &ExternalIdentity{ |
||||
|
Email: "user@example.com", |
||||
|
Provider: "test-provider", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "missing provider", |
||||
|
identity: &ExternalIdentity{ |
||||
|
UserID: "user123", |
||||
|
Email: "user@example.com", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "invalid email", |
||||
|
identity: &ExternalIdentity{ |
||||
|
UserID: "user123", |
||||
|
Email: "invalid-email", |
||||
|
Provider: "test-provider", |
||||
|
}, |
||||
|
wantErr: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
err := tt.identity.Validate() |
||||
|
if tt.wantErr { |
||||
|
assert.Error(t, err) |
||||
|
} else { |
||||
|
assert.NoError(t, err) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestTokenClaimsValidation tests token claims structure
|
||||
|
func TestTokenClaimsValidation(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
claims *TokenClaims |
||||
|
valid bool |
||||
|
}{ |
||||
|
{ |
||||
|
name: "valid claims", |
||||
|
claims: &TokenClaims{ |
||||
|
Subject: "user123", |
||||
|
Issuer: "https://provider.example.com", |
||||
|
Audience: "seaweedfs", |
||||
|
ExpiresAt: time.Now().Add(time.Hour), |
||||
|
IssuedAt: time.Now().Add(-time.Minute), |
||||
|
Claims: map[string]interface{}{"email": "user@example.com"}, |
||||
|
}, |
||||
|
valid: true, |
||||
|
}, |
||||
|
{ |
||||
|
name: "expired token", |
||||
|
claims: &TokenClaims{ |
||||
|
Subject: "user123", |
||||
|
Issuer: "https://provider.example.com", |
||||
|
Audience: "seaweedfs", |
||||
|
ExpiresAt: time.Now().Add(-time.Hour), // Expired
|
||||
|
IssuedAt: time.Now().Add(-time.Hour * 2), |
||||
|
Claims: map[string]interface{}{"email": "user@example.com"}, |
||||
|
}, |
||||
|
valid: false, |
||||
|
}, |
||||
|
{ |
||||
|
name: "future issued token", |
||||
|
claims: &TokenClaims{ |
||||
|
Subject: "user123", |
||||
|
Issuer: "https://provider.example.com", |
||||
|
Audience: "seaweedfs", |
||||
|
ExpiresAt: time.Now().Add(time.Hour), |
||||
|
IssuedAt: time.Now().Add(time.Hour), // Future
|
||||
|
Claims: map[string]interface{}{"email": "user@example.com"}, |
||||
|
}, |
||||
|
valid: false, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
valid := tt.claims.IsValid() |
||||
|
assert.Equal(t, tt.valid, valid) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TestProviderRegistry tests provider registration and discovery
|
||||
|
func TestProviderRegistry(t *testing.T) { |
||||
|
// Clear registry for test
|
||||
|
registry := NewProviderRegistry() |
||||
|
|
||||
|
t.Run("register provider", func(t *testing.T) { |
||||
|
mockProvider := &MockProvider{name: "test-provider"} |
||||
|
|
||||
|
err := registry.RegisterProvider(mockProvider) |
||||
|
assert.NoError(t, err) |
||||
|
|
||||
|
// Test duplicate registration
|
||||
|
err = registry.RegisterProvider(mockProvider) |
||||
|
assert.Error(t, err, "Should not allow duplicate registration") |
||||
|
}) |
||||
|
|
||||
|
t.Run("get provider", func(t *testing.T) { |
||||
|
provider, exists := registry.GetProvider("test-provider") |
||||
|
assert.True(t, exists) |
||||
|
assert.Equal(t, "test-provider", provider.Name()) |
||||
|
|
||||
|
// Test non-existent provider
|
||||
|
_, exists = registry.GetProvider("non-existent") |
||||
|
assert.False(t, exists) |
||||
|
}) |
||||
|
|
||||
|
t.Run("list providers", func(t *testing.T) { |
||||
|
providers := registry.ListProviders() |
||||
|
assert.Len(t, providers, 1) |
||||
|
assert.Equal(t, "test-provider", providers[0]) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// MockProvider for testing
|
||||
|
type MockProvider struct { |
||||
|
name string |
||||
|
initialized bool |
||||
|
shouldError bool |
||||
|
} |
||||
|
|
||||
|
func (m *MockProvider) Name() string { |
||||
|
return m.name |
||||
|
} |
||||
|
|
||||
|
func (m *MockProvider) Initialize(config interface{}) error { |
||||
|
if m.shouldError { |
||||
|
return assert.AnError |
||||
|
} |
||||
|
m.initialized = true |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) { |
||||
|
if !m.initialized { |
||||
|
return nil, assert.AnError |
||||
|
} |
||||
|
if token == "invalid-token" { |
||||
|
return nil, assert.AnError |
||||
|
} |
||||
|
return &ExternalIdentity{ |
||||
|
UserID: "test-user", |
||||
|
Email: "test@example.com", |
||||
|
DisplayName: "Test User", |
||||
|
Provider: m.name, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) { |
||||
|
if !m.initialized || userID == "" { |
||||
|
return nil, assert.AnError |
||||
|
} |
||||
|
return &ExternalIdentity{ |
||||
|
UserID: userID, |
||||
|
Email: userID + "@example.com", |
||||
|
DisplayName: "User " + userID, |
||||
|
Provider: m.name, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) { |
||||
|
if !m.initialized || token == "invalid-token" { |
||||
|
return nil, assert.AnError |
||||
|
} |
||||
|
return &TokenClaims{ |
||||
|
Subject: "test-user", |
||||
|
Issuer: "test-issuer", |
||||
|
Audience: "seaweedfs", |
||||
|
ExpiresAt: time.Now().Add(time.Hour), |
||||
|
IssuedAt: time.Now(), |
||||
|
Claims: map[string]interface{}{"email": "test@example.com"}, |
||||
|
}, nil |
||||
|
} |
@ -0,0 +1,109 @@ |
|||||
|
package providers |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"sync" |
||||
|
) |
||||
|
|
||||
|
// ProviderRegistry manages registered identity providers
|
||||
|
type ProviderRegistry struct { |
||||
|
mu sync.RWMutex |
||||
|
providers map[string]IdentityProvider |
||||
|
} |
||||
|
|
||||
|
// NewProviderRegistry creates a new provider registry
|
||||
|
func NewProviderRegistry() *ProviderRegistry { |
||||
|
return &ProviderRegistry{ |
||||
|
providers: make(map[string]IdentityProvider), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// RegisterProvider registers a new identity provider
|
||||
|
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error { |
||||
|
if provider == nil { |
||||
|
return fmt.Errorf("provider cannot be nil") |
||||
|
} |
||||
|
|
||||
|
name := provider.Name() |
||||
|
if name == "" { |
||||
|
return fmt.Errorf("provider name cannot be empty") |
||||
|
} |
||||
|
|
||||
|
r.mu.Lock() |
||||
|
defer r.mu.Unlock() |
||||
|
|
||||
|
if _, exists := r.providers[name]; exists { |
||||
|
return fmt.Errorf("provider %s is already registered", name) |
||||
|
} |
||||
|
|
||||
|
r.providers[name] = provider |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// GetProvider retrieves a provider by name
|
||||
|
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) { |
||||
|
r.mu.RLock() |
||||
|
defer r.mu.RUnlock() |
||||
|
|
||||
|
provider, exists := r.providers[name] |
||||
|
return provider, exists |
||||
|
} |
||||
|
|
||||
|
// ListProviders returns all registered provider names
|
||||
|
func (r *ProviderRegistry) ListProviders() []string { |
||||
|
r.mu.RLock() |
||||
|
defer r.mu.RUnlock() |
||||
|
|
||||
|
var names []string |
||||
|
for name := range r.providers { |
||||
|
names = append(names, name) |
||||
|
} |
||||
|
return names |
||||
|
} |
||||
|
|
||||
|
// UnregisterProvider removes a provider from the registry
|
||||
|
func (r *ProviderRegistry) UnregisterProvider(name string) error { |
||||
|
r.mu.Lock() |
||||
|
defer r.mu.Unlock() |
||||
|
|
||||
|
if _, exists := r.providers[name]; !exists { |
||||
|
return fmt.Errorf("provider %s is not registered", name) |
||||
|
} |
||||
|
|
||||
|
delete(r.providers, name) |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Clear removes all providers from the registry
|
||||
|
func (r *ProviderRegistry) Clear() { |
||||
|
r.mu.Lock() |
||||
|
defer r.mu.Unlock() |
||||
|
|
||||
|
r.providers = make(map[string]IdentityProvider) |
||||
|
} |
||||
|
|
||||
|
// GetProviderCount returns the number of registered providers
|
||||
|
func (r *ProviderRegistry) GetProviderCount() int { |
||||
|
r.mu.RLock() |
||||
|
defer r.mu.RUnlock() |
||||
|
|
||||
|
return len(r.providers) |
||||
|
} |
||||
|
|
||||
|
// Default global registry
|
||||
|
var defaultRegistry = NewProviderRegistry() |
||||
|
|
||||
|
// RegisterProvider registers a provider in the default registry
|
||||
|
func RegisterProvider(provider IdentityProvider) error { |
||||
|
return defaultRegistry.RegisterProvider(provider) |
||||
|
} |
||||
|
|
||||
|
// GetProvider retrieves a provider from the default registry
|
||||
|
func GetProvider(name string) (IdentityProvider, bool) { |
||||
|
return defaultRegistry.GetProvider(name) |
||||
|
} |
||||
|
|
||||
|
// ListProviders returns all provider names from the default registry
|
||||
|
func ListProviders() []string { |
||||
|
return defaultRegistry.ListProviders() |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue