Browse Source
TDD RED Phase: Add identity provider framework tests
TDD RED Phase: Add identity provider framework tests
- Add core IdentityProvider interface with tests - Add OIDC provider tests with JWT token validation - Add LDAP provider tests with authentication flows - Add ProviderRegistry for managing multiple providers - Tests currently failing as expected in TDD RED phasepull/7160/head
7 changed files with 1597 additions and 0 deletions
-
219weed/iam/ldap/ldap_provider.go
-
357weed/iam/ldap/ldap_provider_test.go
-
124weed/iam/oidc/oidc_provider.go
-
318weed/iam/oidc/oidc_provider_test.go
-
224weed/iam/providers/provider.go
-
246weed/iam/providers/provider_test.go
-
109weed/iam/providers/registry.go
@ -0,0 +1,219 @@ |
|||
package ldap |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"strings" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
) |
|||
|
|||
// LDAPProvider implements LDAP authentication
|
|||
type LDAPProvider struct { |
|||
name string |
|||
config *LDAPConfig |
|||
initialized bool |
|||
connPool interface{} // Will be proper LDAP connection pool
|
|||
} |
|||
|
|||
// LDAPConfig holds LDAP provider configuration
|
|||
type LDAPConfig struct { |
|||
// Server is the LDAP server URL (e.g., ldap://localhost:389)
|
|||
Server string `json:"server"` |
|||
|
|||
// BaseDN is the base distinguished name for searches
|
|||
BaseDN string `json:"baseDn"` |
|||
|
|||
// BindDN is the distinguished name for binding (authentication)
|
|||
BindDN string `json:"bindDn,omitempty"` |
|||
|
|||
// BindPass is the password for binding
|
|||
BindPass string `json:"bindPass,omitempty"` |
|||
|
|||
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
|
|||
UserFilter string `json:"userFilter"` |
|||
|
|||
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
|
|||
GroupFilter string `json:"groupFilter,omitempty"` |
|||
|
|||
// Attributes maps SeaweedFS identity fields to LDAP attributes
|
|||
Attributes map[string]string `json:"attributes,omitempty"` |
|||
|
|||
// RoleMapping defines how to map LDAP groups to roles
|
|||
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
|||
|
|||
// TLS configuration
|
|||
UseTLS bool `json:"useTls,omitempty"` |
|||
TLSCert string `json:"tlsCert,omitempty"` |
|||
TLSKey string `json:"tlsKey,omitempty"` |
|||
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"` |
|||
|
|||
// Connection pool settings
|
|||
MaxConnections int `json:"maxConnections,omitempty"` |
|||
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
|
|||
} |
|||
|
|||
// NewLDAPProvider creates a new LDAP provider
|
|||
func NewLDAPProvider(name string) *LDAPProvider { |
|||
return &LDAPProvider{ |
|||
name: name, |
|||
} |
|||
} |
|||
|
|||
// Name returns the provider name
|
|||
func (p *LDAPProvider) Name() string { |
|||
return p.name |
|||
} |
|||
|
|||
// Initialize initializes the LDAP provider with configuration
|
|||
func (p *LDAPProvider) Initialize(config interface{}) error { |
|||
ldapConfig, ok := config.(*LDAPConfig) |
|||
if !ok { |
|||
return fmt.Errorf("invalid config type for LDAP provider") |
|||
} |
|||
|
|||
if err := p.validateConfig(ldapConfig); err != nil { |
|||
return fmt.Errorf("invalid LDAP configuration: %w", err) |
|||
} |
|||
|
|||
p.config = ldapConfig |
|||
p.initialized = true |
|||
|
|||
// TODO: Initialize LDAP connection pool
|
|||
return fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// validateConfig validates the LDAP configuration
|
|||
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error { |
|||
if config.Server == "" { |
|||
return fmt.Errorf("server is required") |
|||
} |
|||
|
|||
if config.BaseDN == "" { |
|||
return fmt.Errorf("base DN is required") |
|||
} |
|||
|
|||
// Basic URL validation
|
|||
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") { |
|||
return fmt.Errorf("invalid server URL format") |
|||
} |
|||
|
|||
// Set default user filter if not provided
|
|||
if config.UserFilter == "" { |
|||
config.UserFilter = "(uid=%s)" // Default LDAP user filter
|
|||
} |
|||
|
|||
// Set default attributes if not provided
|
|||
if config.Attributes == nil { |
|||
config.Attributes = map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "cn", |
|||
"groups": "memberOf", |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Authenticate authenticates a user with LDAP
|
|||
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
// TODO: Parse credentials (username:password), bind to LDAP, search for user
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// GetUserInfo retrieves user information from LDAP
|
|||
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
if userID == "" { |
|||
return nil, fmt.Errorf("user ID cannot be empty") |
|||
} |
|||
|
|||
// TODO: Search LDAP for user, get attributes
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// ValidateToken validates credentials (for LDAP, this is username/password)
|
|||
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
// TODO: For LDAP, "token" would be username:password format
|
|||
// Validate credentials and return claims
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
|
|||
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity { |
|||
identity := &providers.ExternalIdentity{ |
|||
UserID: userID, |
|||
Provider: p.name, |
|||
Attributes: make(map[string]string), |
|||
} |
|||
|
|||
// Map configured attributes
|
|||
for identityField, ldapAttr := range p.config.Attributes { |
|||
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 { |
|||
switch identityField { |
|||
case "email": |
|||
identity.Email = values[0] |
|||
case "displayName": |
|||
identity.DisplayName = values[0] |
|||
case "groups": |
|||
identity.Groups = values |
|||
default: |
|||
// Store as custom attribute
|
|||
identity.Attributes[identityField] = values[0] |
|||
} |
|||
} |
|||
} |
|||
|
|||
return identity |
|||
} |
|||
|
|||
// mapUserToRole maps user groups to roles based on role mapping rules
|
|||
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string { |
|||
if p.config.RoleMapping == nil { |
|||
return "" |
|||
} |
|||
|
|||
// Create token claims from identity for rule matching
|
|||
claims := &providers.TokenClaims{ |
|||
Subject: identity.UserID, |
|||
Claims: map[string]interface{}{ |
|||
"groups": identity.Groups, |
|||
"email": identity.Email, |
|||
}, |
|||
} |
|||
|
|||
// Check mapping rules
|
|||
for _, rule := range p.config.RoleMapping.Rules { |
|||
if rule.Matches(claims) { |
|||
return rule.Role |
|||
} |
|||
} |
|||
|
|||
// Return default role if no rules match
|
|||
return p.config.RoleMapping.DefaultRole |
|||
} |
|||
|
|||
// Connection management methods (stubs for now)
|
|||
func (p *LDAPProvider) getConnectionPool() interface{} { |
|||
return p.connPool |
|||
} |
|||
|
|||
func (p *LDAPProvider) getConnection() (interface{}, error) { |
|||
// TODO: Get connection from pool
|
|||
return nil, fmt.Errorf("not implemented") |
|||
} |
|||
|
|||
func (p *LDAPProvider) releaseConnection(conn interface{}) { |
|||
// TODO: Return connection to pool
|
|||
} |
@ -0,0 +1,357 @@ |
|||
package ldap |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"testing" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
"github.com/stretchr/testify/assert" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
// TestLDAPProviderInitialization tests LDAP provider initialization
|
|||
func TestLDAPProviderInitialization(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
config *LDAPConfig |
|||
wantErr bool |
|||
}{ |
|||
{ |
|||
name: "valid config", |
|||
config: &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
GroupFilter: "(member=%s)", |
|||
}, |
|||
wantErr: false, |
|||
}, |
|||
{ |
|||
name: "missing server", |
|||
config: &LDAPConfig{ |
|||
BaseDN: "DC=example,DC=com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "missing base DN", |
|||
config: &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "invalid server URL", |
|||
config: &LDAPConfig{ |
|||
Server: "invalid-url", |
|||
BaseDN: "DC=example,DC=com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
|
|||
err := provider.Initialize(tt.config) |
|||
|
|||
if tt.wantErr { |
|||
assert.Error(t, err) |
|||
} else { |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "test-ldap", provider.Name()) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestLDAPProviderAuthentication tests LDAP authentication
|
|||
func TestLDAPProviderAuthentication(t *testing.T) { |
|||
// Skip if no LDAP test server available
|
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP integration test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
GroupFilter: "(member=%s)", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "displayName", |
|||
"groups": "memberOf", |
|||
}, |
|||
RoleMapping: &providers.RoleMapping{ |
|||
Rules: []providers.MappingRule{ |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*CN=Admins*", |
|||
Role: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*CN=Users*", |
|||
Role: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
}, |
|||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("authenticate with username/password", func(t *testing.T) { |
|||
// This would require an actual LDAP server for integration testing
|
|||
credentials := "user:password" // Basic auth format
|
|||
|
|||
identity, err := provider.Authenticate(context.Background(), credentials) |
|||
if err != nil { |
|||
t.Skip("LDAP server not available for testing") |
|||
} |
|||
|
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "user", identity.UserID) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
assert.NotEmpty(t, identity.Email) |
|||
}) |
|||
|
|||
t.Run("authenticate with invalid credentials", func(t *testing.T) { |
|||
_, err := provider.Authenticate(context.Background(), "invalid:credentials") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPProviderUserInfo tests LDAP user info retrieval
|
|||
func TestLDAPProviderUserInfo(t *testing.T) { |
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP integration test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
UserFilter: "(sAMAccountName=%s)", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "displayName", |
|||
"department": "department", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("get user info", func(t *testing.T) { |
|||
identity, err := provider.GetUserInfo(context.Background(), "testuser") |
|||
if err != nil { |
|||
t.Skip("LDAP server not available for testing") |
|||
} |
|||
|
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "testuser", identity.UserID) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
assert.NotEmpty(t, identity.Email) |
|||
assert.NotEmpty(t, identity.DisplayName) |
|||
}) |
|||
|
|||
t.Run("get user info with empty username", func(t *testing.T) { |
|||
_, err := provider.GetUserInfo(context.Background(), "") |
|||
assert.Error(t, err) |
|||
}) |
|||
|
|||
t.Run("get user info for non-existent user", func(t *testing.T) { |
|||
_, err := provider.GetUserInfo(context.Background(), "nonexistent") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPAttributeMapping tests LDAP attribute mapping
|
|||
func TestLDAPAttributeMapping(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
Attributes: map[string]string{ |
|||
"email": "mail", |
|||
"displayName": "cn", |
|||
"department": "departmentNumber", |
|||
"groups": "memberOf", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("map LDAP attributes to identity", func(t *testing.T) { |
|||
ldapAttrs := map[string][]string{ |
|||
"mail": {"user@example.com"}, |
|||
"cn": {"John Doe"}, |
|||
"departmentNumber": {"IT"}, |
|||
"memberOf": { |
|||
"CN=Users,OU=Groups,DC=example,DC=com", |
|||
"CN=Developers,OU=Groups,DC=example,DC=com", |
|||
}, |
|||
} |
|||
|
|||
identity := provider.mapLDAPAttributes("testuser", ldapAttrs) |
|||
|
|||
assert.Equal(t, "testuser", identity.UserID) |
|||
assert.Equal(t, "user@example.com", identity.Email) |
|||
assert.Equal(t, "John Doe", identity.DisplayName) |
|||
assert.Equal(t, "test-ldap", identity.Provider) |
|||
|
|||
// Check groups
|
|||
assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com") |
|||
assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com") |
|||
|
|||
// Check attributes
|
|||
assert.Equal(t, "IT", identity.Attributes["department"]) |
|||
}) |
|||
} |
|||
|
|||
// TestLDAPGroupFiltering tests LDAP group filtering and role mapping
|
|||
func TestLDAPGroupFiltering(t *testing.T) { |
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
RoleMapping: &providers.RoleMapping{ |
|||
Rules: []providers.MappingRule{ |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*Admins*", |
|||
Role: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
Claim: "groups", |
|||
Value: "*Users*", |
|||
Role: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
}, |
|||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
tests := []struct { |
|||
name string |
|||
groups []string |
|||
expectedRole string |
|||
expectedClaims map[string]interface{} |
|||
}{ |
|||
{ |
|||
name: "admin user", |
|||
groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
{ |
|||
name: "regular user", |
|||
groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
{ |
|||
name: "guest user", |
|||
groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"}, |
|||
expectedRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
identity := &providers.ExternalIdentity{ |
|||
UserID: "testuser", |
|||
Groups: tt.groups, |
|||
Provider: "test-ldap", |
|||
} |
|||
|
|||
role := provider.mapUserToRole(identity) |
|||
assert.Equal(t, tt.expectedRole, role) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestLDAPConnectionPool tests LDAP connection pooling
|
|||
func TestLDAPConnectionPool(t *testing.T) { |
|||
if testing.Short() { |
|||
t.Skip("Skipping LDAP connection pool test in short mode") |
|||
} |
|||
|
|||
provider := NewLDAPProvider("test-ldap") |
|||
config := &LDAPConfig{ |
|||
Server: "ldap://localhost:389", |
|||
BaseDN: "DC=example,DC=com", |
|||
BindDN: "CN=admin,DC=example,DC=com", |
|||
BindPass: "password", |
|||
MaxConnections: 5, |
|||
ConnTimeout: 30, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("connection pool management", func(t *testing.T) { |
|||
// Test that multiple concurrent requests work
|
|||
// This would require actual LDAP server for full testing
|
|||
pool := provider.getConnectionPool() |
|||
assert.NotNil(t, pool) |
|||
|
|||
// Test connection acquisition and release
|
|||
conn, err := provider.getConnection() |
|||
if err != nil { |
|||
t.Skip("LDAP server not available") |
|||
} |
|||
|
|||
assert.NotNil(t, conn) |
|||
provider.releaseConnection(conn) |
|||
}) |
|||
} |
|||
|
|||
// MockLDAPServer for unit testing (without external dependencies)
|
|||
type MockLDAPServer struct { |
|||
users map[string]map[string][]string |
|||
} |
|||
|
|||
func NewMockLDAPServer() *MockLDAPServer { |
|||
return &MockLDAPServer{ |
|||
users: map[string]map[string][]string{ |
|||
"testuser": { |
|||
"mail": {"testuser@example.com"}, |
|||
"cn": {"Test User"}, |
|||
"department": {"Engineering"}, |
|||
"memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
}, |
|||
"admin": { |
|||
"mail": {"admin@example.com"}, |
|||
"cn": {"Administrator"}, |
|||
"department": {"IT"}, |
|||
"memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"}, |
|||
}, |
|||
}, |
|||
} |
|||
} |
|||
|
|||
func (m *MockLDAPServer) Authenticate(username, password string) bool { |
|||
_, exists := m.users[username] |
|||
return exists && password == "password" // Mock authentication
|
|||
} |
|||
|
|||
func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) { |
|||
if attrs, exists := m.users[username]; exists { |
|||
return attrs, nil |
|||
} |
|||
return nil, fmt.Errorf("user not found: %s", username) |
|||
} |
@ -0,0 +1,124 @@ |
|||
package oidc |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
) |
|||
|
|||
// OIDCProvider implements OpenID Connect authentication
|
|||
type OIDCProvider struct { |
|||
name string |
|||
config *OIDCConfig |
|||
initialized bool |
|||
} |
|||
|
|||
// OIDCConfig holds OIDC provider configuration
|
|||
type OIDCConfig struct { |
|||
// Issuer is the OIDC issuer URL
|
|||
Issuer string `json:"issuer"` |
|||
|
|||
// ClientID is the OAuth2 client ID
|
|||
ClientID string `json:"clientId"` |
|||
|
|||
// ClientSecret is the OAuth2 client secret (optional for public clients)
|
|||
ClientSecret string `json:"clientSecret,omitempty"` |
|||
|
|||
// JWKSUri is the JSON Web Key Set URI
|
|||
JWKSUri string `json:"jwksUri,omitempty"` |
|||
|
|||
// UserInfoUri is the UserInfo endpoint URI
|
|||
UserInfoUri string `json:"userInfoUri,omitempty"` |
|||
|
|||
// Scopes are the OAuth2 scopes to request
|
|||
Scopes []string `json:"scopes,omitempty"` |
|||
|
|||
// RoleMapping defines how to map OIDC claims to roles
|
|||
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"` |
|||
|
|||
// ClaimsMapping defines how to map OIDC claims to identity attributes
|
|||
ClaimsMapping map[string]string `json:"claimsMapping,omitempty"` |
|||
} |
|||
|
|||
// NewOIDCProvider creates a new OIDC provider
|
|||
func NewOIDCProvider(name string) *OIDCProvider { |
|||
return &OIDCProvider{ |
|||
name: name, |
|||
} |
|||
} |
|||
|
|||
// Name returns the provider name
|
|||
func (p *OIDCProvider) Name() string { |
|||
return p.name |
|||
} |
|||
|
|||
// Initialize initializes the OIDC provider with configuration
|
|||
func (p *OIDCProvider) Initialize(config interface{}) error { |
|||
oidcConfig, ok := config.(*OIDCConfig) |
|||
if !ok { |
|||
return fmt.Errorf("invalid config type for OIDC provider") |
|||
} |
|||
|
|||
if err := p.validateConfig(oidcConfig); err != nil { |
|||
return fmt.Errorf("invalid OIDC configuration: %w", err) |
|||
} |
|||
|
|||
p.config = oidcConfig |
|||
p.initialized = true |
|||
|
|||
// TODO: Initialize OIDC client, fetch JWKS, etc.
|
|||
return fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// validateConfig validates the OIDC configuration
|
|||
func (p *OIDCProvider) validateConfig(config *OIDCConfig) error { |
|||
if config.Issuer == "" { |
|||
return fmt.Errorf("issuer is required") |
|||
} |
|||
|
|||
if config.ClientID == "" { |
|||
return fmt.Errorf("client ID is required") |
|||
} |
|||
|
|||
// Basic URL validation for issuer
|
|||
if config.Issuer != "" && config.Issuer != "https://accounts.google.com" && config.Issuer[0:4] != "http" { |
|||
return fmt.Errorf("invalid issuer URL format") |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Authenticate authenticates a user with an OIDC token
|
|||
func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
// TODO: Validate JWT token, extract claims, map to identity
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// GetUserInfo retrieves user information from the UserInfo endpoint
|
|||
func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
if userID == "" { |
|||
return nil, fmt.Errorf("user ID cannot be empty") |
|||
} |
|||
|
|||
// TODO: Call UserInfo endpoint
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
|||
|
|||
// ValidateToken validates an OIDC JWT token
|
|||
func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { |
|||
if !p.initialized { |
|||
return nil, fmt.Errorf("provider not initialized") |
|||
} |
|||
|
|||
// TODO: Validate JWT signature, claims, expiration
|
|||
return nil, fmt.Errorf("not implemented yet") |
|||
} |
@ -0,0 +1,318 @@ |
|||
package oidc |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"encoding/json" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/golang-jwt/jwt/v5" |
|||
"github.com/seaweedfs/seaweedfs/weed/iam/providers" |
|||
"github.com/stretchr/testify/assert" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
// TestOIDCProviderInitialization tests OIDC provider initialization
|
|||
func TestOIDCProviderInitialization(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
config *OIDCConfig |
|||
wantErr bool |
|||
}{ |
|||
{ |
|||
name: "valid config", |
|||
config: &OIDCConfig{ |
|||
Issuer: "https://accounts.google.com", |
|||
ClientID: "test-client-id", |
|||
JWKSUri: "https://www.googleapis.com/oauth2/v3/certs", |
|||
}, |
|||
wantErr: false, |
|||
}, |
|||
{ |
|||
name: "missing issuer", |
|||
config: &OIDCConfig{ |
|||
ClientID: "test-client-id", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "missing client id", |
|||
config: &OIDCConfig{ |
|||
Issuer: "https://accounts.google.com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "invalid issuer url", |
|||
config: &OIDCConfig{ |
|||
Issuer: "not-a-url", |
|||
ClientID: "test-client-id", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
provider := NewOIDCProvider("test-provider") |
|||
|
|||
err := provider.Initialize(tt.config) |
|||
|
|||
if tt.wantErr { |
|||
assert.Error(t, err) |
|||
} else { |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "test-provider", provider.Name()) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestOIDCProviderJWTValidation tests JWT token validation
|
|||
func TestOIDCProviderJWTValidation(t *testing.T) { |
|||
// Set up test server with JWKS endpoint
|
|||
privateKey, publicKey := generateTestKeys(t) |
|||
|
|||
jwks := map[string]interface{}{ |
|||
"keys": []map[string]interface{}{ |
|||
{ |
|||
"kty": "RSA", |
|||
"kid": "test-key-id", |
|||
"use": "sig", |
|||
"alg": "RS256", |
|||
"n": encodePublicKey(t, publicKey), |
|||
"e": "AQAB", |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.Path == "/.well-known/openid_configuration" { |
|||
config := map[string]interface{}{ |
|||
"issuer": "http://" + r.Host, |
|||
"jwks_uri": "http://" + r.Host + "/jwks", |
|||
} |
|||
json.NewEncoder(w).Encode(config) |
|||
} else if r.URL.Path == "/jwks" { |
|||
json.NewEncoder(w).Encode(jwks) |
|||
} |
|||
})) |
|||
defer server.Close() |
|||
|
|||
provider := NewOIDCProvider("test-oidc") |
|||
config := &OIDCConfig{ |
|||
Issuer: server.URL, |
|||
ClientID: "test-client", |
|||
JWKSUri: server.URL + "/jwks", |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("valid token", func(t *testing.T) { |
|||
// Create valid JWT token
|
|||
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
|||
"iss": server.URL, |
|||
"aud": "test-client", |
|||
"sub": "user123", |
|||
"exp": time.Now().Add(time.Hour).Unix(), |
|||
"iat": time.Now().Unix(), |
|||
"email": "user@example.com", |
|||
"name": "Test User", |
|||
}) |
|||
|
|||
claims, err := provider.ValidateToken(context.Background(), token) |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "user123", claims.Subject) |
|||
assert.Equal(t, server.URL, claims.Issuer) |
|||
|
|||
email, exists := claims.GetClaimString("email") |
|||
assert.True(t, exists) |
|||
assert.Equal(t, "user@example.com", email) |
|||
}) |
|||
|
|||
t.Run("expired token", func(t *testing.T) { |
|||
// Create expired JWT token
|
|||
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
|||
"iss": server.URL, |
|||
"aud": "test-client", |
|||
"sub": "user123", |
|||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired
|
|||
"iat": time.Now().Add(-time.Hour * 2).Unix(), |
|||
}) |
|||
|
|||
_, err := provider.ValidateToken(context.Background(), token) |
|||
assert.Error(t, err) |
|||
}) |
|||
|
|||
t.Run("invalid signature", func(t *testing.T) { |
|||
// Create token with wrong key
|
|||
wrongKey, _ := generateTestKeys(t) |
|||
token := createTestJWT(t, wrongKey, jwt.MapClaims{ |
|||
"iss": server.URL, |
|||
"aud": "test-client", |
|||
"sub": "user123", |
|||
"exp": time.Now().Add(time.Hour).Unix(), |
|||
"iat": time.Now().Unix(), |
|||
}) |
|||
|
|||
_, err := provider.ValidateToken(context.Background(), token) |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestOIDCProviderAuthentication tests authentication flow
|
|||
func TestOIDCProviderAuthentication(t *testing.T) { |
|||
// Set up test OIDC provider
|
|||
privateKey, publicKey := generateTestKeys(t) |
|||
|
|||
server := setupOIDCTestServer(t, publicKey) |
|||
defer server.Close() |
|||
|
|||
provider := NewOIDCProvider("test-oidc") |
|||
config := &OIDCConfig{ |
|||
Issuer: server.URL, |
|||
ClientID: "test-client", |
|||
JWKSUri: server.URL + "/jwks", |
|||
RoleMapping: &providers.RoleMapping{ |
|||
Rules: []providers.MappingRule{ |
|||
{ |
|||
Claim: "email", |
|||
Value: "*@example.com", |
|||
Role: "arn:seaweed:iam::role/UserRole", |
|||
}, |
|||
{ |
|||
Claim: "groups", |
|||
Value: "admins", |
|||
Role: "arn:seaweed:iam::role/AdminRole", |
|||
}, |
|||
}, |
|||
DefaultRole: "arn:seaweed:iam::role/GuestRole", |
|||
}, |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("successful authentication", func(t *testing.T) { |
|||
token := createTestJWT(t, privateKey, jwt.MapClaims{ |
|||
"iss": server.URL, |
|||
"aud": "test-client", |
|||
"sub": "user123", |
|||
"exp": time.Now().Add(time.Hour).Unix(), |
|||
"iat": time.Now().Unix(), |
|||
"email": "user@example.com", |
|||
"name": "Test User", |
|||
"groups": []string{"users", "developers"}, |
|||
}) |
|||
|
|||
identity, err := provider.Authenticate(context.Background(), token) |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "user123", identity.UserID) |
|||
assert.Equal(t, "user@example.com", identity.Email) |
|||
assert.Equal(t, "Test User", identity.DisplayName) |
|||
assert.Equal(t, "test-oidc", identity.Provider) |
|||
assert.Contains(t, identity.Groups, "users") |
|||
assert.Contains(t, identity.Groups, "developers") |
|||
}) |
|||
|
|||
t.Run("authentication with invalid token", func(t *testing.T) { |
|||
_, err := provider.Authenticate(context.Background(), "invalid-token") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// TestOIDCProviderUserInfo tests user info retrieval
|
|||
func TestOIDCProviderUserInfo(t *testing.T) { |
|||
// Set up test server with UserInfo endpoint
|
|||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.Path == "/userinfo" { |
|||
userInfo := map[string]interface{}{ |
|||
"sub": r.URL.Query().Get("user_id"), |
|||
"email": "user@example.com", |
|||
"name": "Test User", |
|||
} |
|||
json.NewEncoder(w).Encode(userInfo) |
|||
} |
|||
})) |
|||
defer server.Close() |
|||
|
|||
provider := NewOIDCProvider("test-oidc") |
|||
config := &OIDCConfig{ |
|||
Issuer: server.URL, |
|||
ClientID: "test-client", |
|||
UserInfoUri: server.URL + "/userinfo", |
|||
} |
|||
|
|||
err := provider.Initialize(config) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("get user info", func(t *testing.T) { |
|||
identity, err := provider.GetUserInfo(context.Background(), "user123") |
|||
assert.NoError(t, err) |
|||
assert.Equal(t, "user123", identity.UserID) |
|||
assert.Equal(t, "user@example.com", identity.Email) |
|||
assert.Equal(t, "Test User", identity.DisplayName) |
|||
}) |
|||
|
|||
t.Run("get user info with empty id", func(t *testing.T) { |
|||
_, err := provider.GetUserInfo(context.Background(), "") |
|||
assert.Error(t, err) |
|||
}) |
|||
} |
|||
|
|||
// Helper functions for testing
|
|||
|
|||
func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { |
|||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
|||
require.NoError(t, err) |
|||
return privateKey, &privateKey.PublicKey |
|||
} |
|||
|
|||
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string { |
|||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) |
|||
token.Header["kid"] = "test-key-id" |
|||
|
|||
tokenString, err := token.SignedString(privateKey) |
|||
require.NoError(t, err) |
|||
return tokenString |
|||
} |
|||
|
|||
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string { |
|||
// This is a simplified version - real implementation would properly encode the public key
|
|||
return "test-public-key-n-value" |
|||
} |
|||
|
|||
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server { |
|||
jwks := map[string]interface{}{ |
|||
"keys": []map[string]interface{}{ |
|||
{ |
|||
"kty": "RSA", |
|||
"kid": "test-key-id", |
|||
"use": "sig", |
|||
"alg": "RS256", |
|||
"n": encodePublicKey(t, publicKey), |
|||
"e": "AQAB", |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
switch r.URL.Path { |
|||
case "/.well-known/openid_configuration": |
|||
config := map[string]interface{}{ |
|||
"issuer": "http://" + r.Host, |
|||
"jwks_uri": "http://" + r.Host + "/jwks", |
|||
} |
|||
json.NewEncoder(w).Encode(config) |
|||
case "/jwks": |
|||
json.NewEncoder(w).Encode(jwks) |
|||
default: |
|||
http.NotFound(w, r) |
|||
} |
|||
})) |
|||
} |
@ -0,0 +1,224 @@ |
|||
package providers |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"net/mail" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// IdentityProvider defines the interface for external identity providers
|
|||
type IdentityProvider interface { |
|||
// Name returns the unique name of the provider
|
|||
Name() string |
|||
|
|||
// Initialize initializes the provider with configuration
|
|||
Initialize(config interface{}) error |
|||
|
|||
// Authenticate authenticates a user with a token and returns external identity
|
|||
Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) |
|||
|
|||
// GetUserInfo retrieves user information by user ID
|
|||
GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) |
|||
|
|||
// ValidateToken validates a token and returns claims
|
|||
ValidateToken(ctx context.Context, token string) (*TokenClaims, error) |
|||
} |
|||
|
|||
// ExternalIdentity represents an identity from an external provider
|
|||
type ExternalIdentity struct { |
|||
// UserID is the unique identifier from the external provider
|
|||
UserID string `json:"userId"` |
|||
|
|||
// Email is the user's email address
|
|||
Email string `json:"email"` |
|||
|
|||
// DisplayName is the user's display name
|
|||
DisplayName string `json:"displayName"` |
|||
|
|||
// Groups are the groups the user belongs to
|
|||
Groups []string `json:"groups,omitempty"` |
|||
|
|||
// Attributes are additional user attributes
|
|||
Attributes map[string]string `json:"attributes,omitempty"` |
|||
|
|||
// Provider is the name of the identity provider
|
|||
Provider string `json:"provider"` |
|||
} |
|||
|
|||
// Validate validates the external identity structure
|
|||
func (e *ExternalIdentity) Validate() error { |
|||
if e.UserID == "" { |
|||
return fmt.Errorf("user ID is required") |
|||
} |
|||
|
|||
if e.Provider == "" { |
|||
return fmt.Errorf("provider is required") |
|||
} |
|||
|
|||
if e.Email != "" { |
|||
if _, err := mail.ParseAddress(e.Email); err != nil { |
|||
return fmt.Errorf("invalid email format: %w", err) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// TokenClaims represents claims from a validated token
|
|||
type TokenClaims struct { |
|||
// Subject (sub) - user identifier
|
|||
Subject string `json:"sub"` |
|||
|
|||
// Issuer (iss) - token issuer
|
|||
Issuer string `json:"iss"` |
|||
|
|||
// Audience (aud) - intended audience
|
|||
Audience string `json:"aud"` |
|||
|
|||
// ExpiresAt (exp) - expiration time
|
|||
ExpiresAt time.Time `json:"exp"` |
|||
|
|||
// IssuedAt (iat) - issued at time
|
|||
IssuedAt time.Time `json:"iat"` |
|||
|
|||
// NotBefore (nbf) - not valid before time
|
|||
NotBefore time.Time `json:"nbf,omitempty"` |
|||
|
|||
// Claims are additional claims from the token
|
|||
Claims map[string]interface{} `json:"claims,omitempty"` |
|||
} |
|||
|
|||
// IsValid checks if the token claims are valid (not expired, etc.)
|
|||
func (c *TokenClaims) IsValid() bool { |
|||
now := time.Now() |
|||
|
|||
// Check expiration
|
|||
if !c.ExpiresAt.IsZero() && now.After(c.ExpiresAt) { |
|||
return false |
|||
} |
|||
|
|||
// Check not before
|
|||
if !c.NotBefore.IsZero() && now.Before(c.NotBefore) { |
|||
return false |
|||
} |
|||
|
|||
// Check issued at (shouldn't be in the future)
|
|||
if !c.IssuedAt.IsZero() && now.Before(c.IssuedAt) { |
|||
return false |
|||
} |
|||
|
|||
return true |
|||
} |
|||
|
|||
// GetClaimString returns a string claim value
|
|||
func (c *TokenClaims) GetClaimString(key string) (string, bool) { |
|||
if value, exists := c.Claims[key]; exists { |
|||
if str, ok := value.(string); ok { |
|||
return str, true |
|||
} |
|||
} |
|||
return "", false |
|||
} |
|||
|
|||
// GetClaimStringSlice returns a string slice claim value
|
|||
func (c *TokenClaims) GetClaimStringSlice(key string) ([]string, bool) { |
|||
if value, exists := c.Claims[key]; exists { |
|||
switch v := value.(type) { |
|||
case []string: |
|||
return v, true |
|||
case []interface{}: |
|||
var result []string |
|||
for _, item := range v { |
|||
if str, ok := item.(string); ok { |
|||
result = append(result, str) |
|||
} |
|||
} |
|||
return result, len(result) > 0 |
|||
case string: |
|||
// Single string can be treated as slice
|
|||
return []string{v}, true |
|||
} |
|||
} |
|||
return nil, false |
|||
} |
|||
|
|||
// ProviderConfig represents configuration for identity providers
|
|||
type ProviderConfig struct { |
|||
// Type of provider (oidc, ldap, saml)
|
|||
Type string `json:"type"` |
|||
|
|||
// Name of the provider instance
|
|||
Name string `json:"name"` |
|||
|
|||
// Enabled indicates if the provider is active
|
|||
Enabled bool `json:"enabled"` |
|||
|
|||
// Config is provider-specific configuration
|
|||
Config map[string]interface{} `json:"config"` |
|||
|
|||
// RoleMapping defines how to map external identities to roles
|
|||
RoleMapping *RoleMapping `json:"roleMapping,omitempty"` |
|||
} |
|||
|
|||
// RoleMapping defines rules for mapping external identities to roles
|
|||
type RoleMapping struct { |
|||
// Rules are the mapping rules
|
|||
Rules []MappingRule `json:"rules"` |
|||
|
|||
// DefaultRole is assigned if no rules match
|
|||
DefaultRole string `json:"defaultRole,omitempty"` |
|||
} |
|||
|
|||
// MappingRule defines a single mapping rule
|
|||
type MappingRule struct { |
|||
// Claim is the claim key to check
|
|||
Claim string `json:"claim"` |
|||
|
|||
// Value is the expected claim value (supports wildcards)
|
|||
Value string `json:"value"` |
|||
|
|||
// Role is the role ARN to assign
|
|||
Role string `json:"role"` |
|||
|
|||
// Condition is additional condition logic (optional)
|
|||
Condition string `json:"condition,omitempty"` |
|||
} |
|||
|
|||
// Matches checks if a rule matches the given claims
|
|||
func (r *MappingRule) Matches(claims *TokenClaims) bool { |
|||
if r.Claim == "" || r.Value == "" { |
|||
return false |
|||
} |
|||
|
|||
claimValue, exists := claims.GetClaimString(r.Claim) |
|||
if !exists { |
|||
// Try as string slice
|
|||
if claimSlice, sliceExists := claims.GetClaimStringSlice(r.Claim); sliceExists { |
|||
for _, val := range claimSlice { |
|||
if r.matchValue(val) { |
|||
return true |
|||
} |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
return r.matchValue(claimValue) |
|||
} |
|||
|
|||
// matchValue checks if a value matches the rule value (with wildcard support)
|
|||
func (r *MappingRule) matchValue(value string) bool { |
|||
// Simple wildcard matching
|
|||
if strings.Contains(r.Value, "*") { |
|||
// Convert wildcard to regex-like matching
|
|||
pattern := strings.ReplaceAll(r.Value, "*", "") |
|||
if pattern == "" { |
|||
return true // "*" matches everything
|
|||
} |
|||
return strings.Contains(value, pattern) |
|||
} |
|||
|
|||
return value == r.Value |
|||
} |
@ -0,0 +1,246 @@ |
|||
package providers |
|||
|
|||
import ( |
|||
"context" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
// TestIdentityProviderInterface tests the core identity provider interface
|
|||
func TestIdentityProviderInterface(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
provider IdentityProvider |
|||
wantErr bool |
|||
}{ |
|||
// We'll add test cases as we implement providers
|
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
// Test provider name
|
|||
name := tt.provider.Name() |
|||
assert.NotEmpty(t, name, "Provider name should not be empty") |
|||
|
|||
// Test initialization
|
|||
err := tt.provider.Initialize(nil) |
|||
if tt.wantErr { |
|||
assert.Error(t, err) |
|||
return |
|||
} |
|||
require.NoError(t, err) |
|||
|
|||
// Test authentication with invalid token
|
|||
ctx := context.Background() |
|||
_, err = tt.provider.Authenticate(ctx, "invalid-token") |
|||
assert.Error(t, err, "Should fail with invalid token") |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestExternalIdentityValidation tests external identity structure validation
|
|||
func TestExternalIdentityValidation(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
identity *ExternalIdentity |
|||
wantErr bool |
|||
}{ |
|||
{ |
|||
name: "valid identity", |
|||
identity: &ExternalIdentity{ |
|||
UserID: "user123", |
|||
Email: "user@example.com", |
|||
DisplayName: "Test User", |
|||
Groups: []string{"group1", "group2"}, |
|||
Attributes: map[string]string{"dept": "engineering"}, |
|||
Provider: "test-provider", |
|||
}, |
|||
wantErr: false, |
|||
}, |
|||
{ |
|||
name: "missing user id", |
|||
identity: &ExternalIdentity{ |
|||
Email: "user@example.com", |
|||
Provider: "test-provider", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "missing provider", |
|||
identity: &ExternalIdentity{ |
|||
UserID: "user123", |
|||
Email: "user@example.com", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
{ |
|||
name: "invalid email", |
|||
identity: &ExternalIdentity{ |
|||
UserID: "user123", |
|||
Email: "invalid-email", |
|||
Provider: "test-provider", |
|||
}, |
|||
wantErr: true, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
err := tt.identity.Validate() |
|||
if tt.wantErr { |
|||
assert.Error(t, err) |
|||
} else { |
|||
assert.NoError(t, err) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestTokenClaimsValidation tests token claims structure
|
|||
func TestTokenClaimsValidation(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
claims *TokenClaims |
|||
valid bool |
|||
}{ |
|||
{ |
|||
name: "valid claims", |
|||
claims: &TokenClaims{ |
|||
Subject: "user123", |
|||
Issuer: "https://provider.example.com", |
|||
Audience: "seaweedfs", |
|||
ExpiresAt: time.Now().Add(time.Hour), |
|||
IssuedAt: time.Now().Add(-time.Minute), |
|||
Claims: map[string]interface{}{"email": "user@example.com"}, |
|||
}, |
|||
valid: true, |
|||
}, |
|||
{ |
|||
name: "expired token", |
|||
claims: &TokenClaims{ |
|||
Subject: "user123", |
|||
Issuer: "https://provider.example.com", |
|||
Audience: "seaweedfs", |
|||
ExpiresAt: time.Now().Add(-time.Hour), // Expired
|
|||
IssuedAt: time.Now().Add(-time.Hour * 2), |
|||
Claims: map[string]interface{}{"email": "user@example.com"}, |
|||
}, |
|||
valid: false, |
|||
}, |
|||
{ |
|||
name: "future issued token", |
|||
claims: &TokenClaims{ |
|||
Subject: "user123", |
|||
Issuer: "https://provider.example.com", |
|||
Audience: "seaweedfs", |
|||
ExpiresAt: time.Now().Add(time.Hour), |
|||
IssuedAt: time.Now().Add(time.Hour), // Future
|
|||
Claims: map[string]interface{}{"email": "user@example.com"}, |
|||
}, |
|||
valid: false, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
valid := tt.claims.IsValid() |
|||
assert.Equal(t, tt.valid, valid) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// TestProviderRegistry tests provider registration and discovery
|
|||
func TestProviderRegistry(t *testing.T) { |
|||
// Clear registry for test
|
|||
registry := NewProviderRegistry() |
|||
|
|||
t.Run("register provider", func(t *testing.T) { |
|||
mockProvider := &MockProvider{name: "test-provider"} |
|||
|
|||
err := registry.RegisterProvider(mockProvider) |
|||
assert.NoError(t, err) |
|||
|
|||
// Test duplicate registration
|
|||
err = registry.RegisterProvider(mockProvider) |
|||
assert.Error(t, err, "Should not allow duplicate registration") |
|||
}) |
|||
|
|||
t.Run("get provider", func(t *testing.T) { |
|||
provider, exists := registry.GetProvider("test-provider") |
|||
assert.True(t, exists) |
|||
assert.Equal(t, "test-provider", provider.Name()) |
|||
|
|||
// Test non-existent provider
|
|||
_, exists = registry.GetProvider("non-existent") |
|||
assert.False(t, exists) |
|||
}) |
|||
|
|||
t.Run("list providers", func(t *testing.T) { |
|||
providers := registry.ListProviders() |
|||
assert.Len(t, providers, 1) |
|||
assert.Equal(t, "test-provider", providers[0]) |
|||
}) |
|||
} |
|||
|
|||
// MockProvider for testing
|
|||
type MockProvider struct { |
|||
name string |
|||
initialized bool |
|||
shouldError bool |
|||
} |
|||
|
|||
func (m *MockProvider) Name() string { |
|||
return m.name |
|||
} |
|||
|
|||
func (m *MockProvider) Initialize(config interface{}) error { |
|||
if m.shouldError { |
|||
return assert.AnError |
|||
} |
|||
m.initialized = true |
|||
return nil |
|||
} |
|||
|
|||
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) { |
|||
if !m.initialized { |
|||
return nil, assert.AnError |
|||
} |
|||
if token == "invalid-token" { |
|||
return nil, assert.AnError |
|||
} |
|||
return &ExternalIdentity{ |
|||
UserID: "test-user", |
|||
Email: "test@example.com", |
|||
DisplayName: "Test User", |
|||
Provider: m.name, |
|||
}, nil |
|||
} |
|||
|
|||
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) { |
|||
if !m.initialized || userID == "" { |
|||
return nil, assert.AnError |
|||
} |
|||
return &ExternalIdentity{ |
|||
UserID: userID, |
|||
Email: userID + "@example.com", |
|||
DisplayName: "User " + userID, |
|||
Provider: m.name, |
|||
}, nil |
|||
} |
|||
|
|||
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) { |
|||
if !m.initialized || token == "invalid-token" { |
|||
return nil, assert.AnError |
|||
} |
|||
return &TokenClaims{ |
|||
Subject: "test-user", |
|||
Issuer: "test-issuer", |
|||
Audience: "seaweedfs", |
|||
ExpiresAt: time.Now().Add(time.Hour), |
|||
IssuedAt: time.Now(), |
|||
Claims: map[string]interface{}{"email": "test@example.com"}, |
|||
}, nil |
|||
} |
@ -0,0 +1,109 @@ |
|||
package providers |
|||
|
|||
import ( |
|||
"fmt" |
|||
"sync" |
|||
) |
|||
|
|||
// ProviderRegistry manages registered identity providers
|
|||
type ProviderRegistry struct { |
|||
mu sync.RWMutex |
|||
providers map[string]IdentityProvider |
|||
} |
|||
|
|||
// NewProviderRegistry creates a new provider registry
|
|||
func NewProviderRegistry() *ProviderRegistry { |
|||
return &ProviderRegistry{ |
|||
providers: make(map[string]IdentityProvider), |
|||
} |
|||
} |
|||
|
|||
// RegisterProvider registers a new identity provider
|
|||
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error { |
|||
if provider == nil { |
|||
return fmt.Errorf("provider cannot be nil") |
|||
} |
|||
|
|||
name := provider.Name() |
|||
if name == "" { |
|||
return fmt.Errorf("provider name cannot be empty") |
|||
} |
|||
|
|||
r.mu.Lock() |
|||
defer r.mu.Unlock() |
|||
|
|||
if _, exists := r.providers[name]; exists { |
|||
return fmt.Errorf("provider %s is already registered", name) |
|||
} |
|||
|
|||
r.providers[name] = provider |
|||
return nil |
|||
} |
|||
|
|||
// GetProvider retrieves a provider by name
|
|||
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) { |
|||
r.mu.RLock() |
|||
defer r.mu.RUnlock() |
|||
|
|||
provider, exists := r.providers[name] |
|||
return provider, exists |
|||
} |
|||
|
|||
// ListProviders returns all registered provider names
|
|||
func (r *ProviderRegistry) ListProviders() []string { |
|||
r.mu.RLock() |
|||
defer r.mu.RUnlock() |
|||
|
|||
var names []string |
|||
for name := range r.providers { |
|||
names = append(names, name) |
|||
} |
|||
return names |
|||
} |
|||
|
|||
// UnregisterProvider removes a provider from the registry
|
|||
func (r *ProviderRegistry) UnregisterProvider(name string) error { |
|||
r.mu.Lock() |
|||
defer r.mu.Unlock() |
|||
|
|||
if _, exists := r.providers[name]; !exists { |
|||
return fmt.Errorf("provider %s is not registered", name) |
|||
} |
|||
|
|||
delete(r.providers, name) |
|||
return nil |
|||
} |
|||
|
|||
// Clear removes all providers from the registry
|
|||
func (r *ProviderRegistry) Clear() { |
|||
r.mu.Lock() |
|||
defer r.mu.Unlock() |
|||
|
|||
r.providers = make(map[string]IdentityProvider) |
|||
} |
|||
|
|||
// GetProviderCount returns the number of registered providers
|
|||
func (r *ProviderRegistry) GetProviderCount() int { |
|||
r.mu.RLock() |
|||
defer r.mu.RUnlock() |
|||
|
|||
return len(r.providers) |
|||
} |
|||
|
|||
// Default global registry
|
|||
var defaultRegistry = NewProviderRegistry() |
|||
|
|||
// RegisterProvider registers a provider in the default registry
|
|||
func RegisterProvider(provider IdentityProvider) error { |
|||
return defaultRegistry.RegisterProvider(provider) |
|||
} |
|||
|
|||
// GetProvider retrieves a provider from the default registry
|
|||
func GetProvider(name string) (IdentityProvider, bool) { |
|||
return defaultRegistry.GetProvider(name) |
|||
} |
|||
|
|||
// ListProviders returns all provider names from the default registry
|
|||
func ListProviders() []string { |
|||
return defaultRegistry.ListProviders() |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue