Browse Source

TDD RED Phase: Add identity provider framework tests

- Add core IdentityProvider interface with tests
- Add OIDC provider tests with JWT token validation
- Add LDAP provider tests with authentication flows
- Add ProviderRegistry for managing multiple providers
- Tests currently failing as expected in TDD RED phase
pull/7160/head
chrislu 1 month ago
parent
commit
ffab92e6cd
  1. 219
      weed/iam/ldap/ldap_provider.go
  2. 357
      weed/iam/ldap/ldap_provider_test.go
  3. 124
      weed/iam/oidc/oidc_provider.go
  4. 318
      weed/iam/oidc/oidc_provider_test.go
  5. 224
      weed/iam/providers/provider.go
  6. 246
      weed/iam/providers/provider_test.go
  7. 109
      weed/iam/providers/registry.go

219
weed/iam/ldap/ldap_provider.go

@ -0,0 +1,219 @@
package ldap
import (
"context"
"fmt"
"strings"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// LDAPProvider implements LDAP authentication
type LDAPProvider struct {
name string
config *LDAPConfig
initialized bool
connPool interface{} // Will be proper LDAP connection pool
}
// LDAPConfig holds LDAP provider configuration
type LDAPConfig struct {
// Server is the LDAP server URL (e.g., ldap://localhost:389)
Server string `json:"server"`
// BaseDN is the base distinguished name for searches
BaseDN string `json:"baseDn"`
// BindDN is the distinguished name for binding (authentication)
BindDN string `json:"bindDn,omitempty"`
// BindPass is the password for binding
BindPass string `json:"bindPass,omitempty"`
// UserFilter is the LDAP filter for finding users (e.g., "(sAMAccountName=%s)")
UserFilter string `json:"userFilter"`
// GroupFilter is the LDAP filter for finding groups (e.g., "(member=%s)")
GroupFilter string `json:"groupFilter,omitempty"`
// Attributes maps SeaweedFS identity fields to LDAP attributes
Attributes map[string]string `json:"attributes,omitempty"`
// RoleMapping defines how to map LDAP groups to roles
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"`
// TLS configuration
UseTLS bool `json:"useTls,omitempty"`
TLSCert string `json:"tlsCert,omitempty"`
TLSKey string `json:"tlsKey,omitempty"`
TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"`
// Connection pool settings
MaxConnections int `json:"maxConnections,omitempty"`
ConnTimeout int `json:"connTimeout,omitempty"` // seconds
}
// NewLDAPProvider creates a new LDAP provider
func NewLDAPProvider(name string) *LDAPProvider {
return &LDAPProvider{
name: name,
}
}
// Name returns the provider name
func (p *LDAPProvider) Name() string {
return p.name
}
// Initialize initializes the LDAP provider with configuration
func (p *LDAPProvider) Initialize(config interface{}) error {
ldapConfig, ok := config.(*LDAPConfig)
if !ok {
return fmt.Errorf("invalid config type for LDAP provider")
}
if err := p.validateConfig(ldapConfig); err != nil {
return fmt.Errorf("invalid LDAP configuration: %w", err)
}
p.config = ldapConfig
p.initialized = true
// TODO: Initialize LDAP connection pool
return fmt.Errorf("not implemented yet")
}
// validateConfig validates the LDAP configuration
func (p *LDAPProvider) validateConfig(config *LDAPConfig) error {
if config.Server == "" {
return fmt.Errorf("server is required")
}
if config.BaseDN == "" {
return fmt.Errorf("base DN is required")
}
// Basic URL validation
if !strings.HasPrefix(config.Server, "ldap://") && !strings.HasPrefix(config.Server, "ldaps://") {
return fmt.Errorf("invalid server URL format")
}
// Set default user filter if not provided
if config.UserFilter == "" {
config.UserFilter = "(uid=%s)" // Default LDAP user filter
}
// Set default attributes if not provided
if config.Attributes == nil {
config.Attributes = map[string]string{
"email": "mail",
"displayName": "cn",
"groups": "memberOf",
}
}
return nil
}
// Authenticate authenticates a user with LDAP
func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
// TODO: Parse credentials (username:password), bind to LDAP, search for user
return nil, fmt.Errorf("not implemented yet")
}
// GetUserInfo retrieves user information from LDAP
func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if userID == "" {
return nil, fmt.Errorf("user ID cannot be empty")
}
// TODO: Search LDAP for user, get attributes
return nil, fmt.Errorf("not implemented yet")
}
// ValidateToken validates credentials (for LDAP, this is username/password)
func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
// TODO: For LDAP, "token" would be username:password format
// Validate credentials and return claims
return nil, fmt.Errorf("not implemented yet")
}
// mapLDAPAttributes maps LDAP attributes to ExternalIdentity
func (p *LDAPProvider) mapLDAPAttributes(userID string, attrs map[string][]string) *providers.ExternalIdentity {
identity := &providers.ExternalIdentity{
UserID: userID,
Provider: p.name,
Attributes: make(map[string]string),
}
// Map configured attributes
for identityField, ldapAttr := range p.config.Attributes {
if values, exists := attrs[ldapAttr]; exists && len(values) > 0 {
switch identityField {
case "email":
identity.Email = values[0]
case "displayName":
identity.DisplayName = values[0]
case "groups":
identity.Groups = values
default:
// Store as custom attribute
identity.Attributes[identityField] = values[0]
}
}
}
return identity
}
// mapUserToRole maps user groups to roles based on role mapping rules
func (p *LDAPProvider) mapUserToRole(identity *providers.ExternalIdentity) string {
if p.config.RoleMapping == nil {
return ""
}
// Create token claims from identity for rule matching
claims := &providers.TokenClaims{
Subject: identity.UserID,
Claims: map[string]interface{}{
"groups": identity.Groups,
"email": identity.Email,
},
}
// Check mapping rules
for _, rule := range p.config.RoleMapping.Rules {
if rule.Matches(claims) {
return rule.Role
}
}
// Return default role if no rules match
return p.config.RoleMapping.DefaultRole
}
// Connection management methods (stubs for now)
func (p *LDAPProvider) getConnectionPool() interface{} {
return p.connPool
}
func (p *LDAPProvider) getConnection() (interface{}, error) {
// TODO: Get connection from pool
return nil, fmt.Errorf("not implemented")
}
func (p *LDAPProvider) releaseConnection(conn interface{}) {
// TODO: Return connection to pool
}

357
weed/iam/ldap/ldap_provider_test.go

@ -0,0 +1,357 @@
package ldap
import (
"context"
"fmt"
"testing"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLDAPProviderInitialization tests LDAP provider initialization
func TestLDAPProviderInitialization(t *testing.T) {
tests := []struct {
name string
config *LDAPConfig
wantErr bool
}{
{
name: "valid config",
config: &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
BindDN: "CN=admin,DC=example,DC=com",
BindPass: "password",
UserFilter: "(sAMAccountName=%s)",
GroupFilter: "(member=%s)",
},
wantErr: false,
},
{
name: "missing server",
config: &LDAPConfig{
BaseDN: "DC=example,DC=com",
},
wantErr: true,
},
{
name: "missing base DN",
config: &LDAPConfig{
Server: "ldap://localhost:389",
},
wantErr: true,
},
{
name: "invalid server URL",
config: &LDAPConfig{
Server: "invalid-url",
BaseDN: "DC=example,DC=com",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := NewLDAPProvider("test-ldap")
err := provider.Initialize(tt.config)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, "test-ldap", provider.Name())
}
})
}
}
// TestLDAPProviderAuthentication tests LDAP authentication
func TestLDAPProviderAuthentication(t *testing.T) {
// Skip if no LDAP test server available
if testing.Short() {
t.Skip("Skipping LDAP integration test in short mode")
}
provider := NewLDAPProvider("test-ldap")
config := &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
BindDN: "CN=admin,DC=example,DC=com",
BindPass: "password",
UserFilter: "(sAMAccountName=%s)",
GroupFilter: "(member=%s)",
Attributes: map[string]string{
"email": "mail",
"displayName": "displayName",
"groups": "memberOf",
},
RoleMapping: &providers.RoleMapping{
Rules: []providers.MappingRule{
{
Claim: "groups",
Value: "*CN=Admins*",
Role: "arn:seaweed:iam::role/AdminRole",
},
{
Claim: "groups",
Value: "*CN=Users*",
Role: "arn:seaweed:iam::role/UserRole",
},
},
DefaultRole: "arn:seaweed:iam::role/GuestRole",
},
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("authenticate with username/password", func(t *testing.T) {
// This would require an actual LDAP server for integration testing
credentials := "user:password" // Basic auth format
identity, err := provider.Authenticate(context.Background(), credentials)
if err != nil {
t.Skip("LDAP server not available for testing")
}
assert.NoError(t, err)
assert.Equal(t, "user", identity.UserID)
assert.Equal(t, "test-ldap", identity.Provider)
assert.NotEmpty(t, identity.Email)
})
t.Run("authenticate with invalid credentials", func(t *testing.T) {
_, err := provider.Authenticate(context.Background(), "invalid:credentials")
assert.Error(t, err)
})
}
// TestLDAPProviderUserInfo tests LDAP user info retrieval
func TestLDAPProviderUserInfo(t *testing.T) {
if testing.Short() {
t.Skip("Skipping LDAP integration test in short mode")
}
provider := NewLDAPProvider("test-ldap")
config := &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
BindDN: "CN=admin,DC=example,DC=com",
BindPass: "password",
UserFilter: "(sAMAccountName=%s)",
Attributes: map[string]string{
"email": "mail",
"displayName": "displayName",
"department": "department",
},
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("get user info", func(t *testing.T) {
identity, err := provider.GetUserInfo(context.Background(), "testuser")
if err != nil {
t.Skip("LDAP server not available for testing")
}
assert.NoError(t, err)
assert.Equal(t, "testuser", identity.UserID)
assert.Equal(t, "test-ldap", identity.Provider)
assert.NotEmpty(t, identity.Email)
assert.NotEmpty(t, identity.DisplayName)
})
t.Run("get user info with empty username", func(t *testing.T) {
_, err := provider.GetUserInfo(context.Background(), "")
assert.Error(t, err)
})
t.Run("get user info for non-existent user", func(t *testing.T) {
_, err := provider.GetUserInfo(context.Background(), "nonexistent")
assert.Error(t, err)
})
}
// TestLDAPAttributeMapping tests LDAP attribute mapping
func TestLDAPAttributeMapping(t *testing.T) {
provider := NewLDAPProvider("test-ldap")
config := &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
Attributes: map[string]string{
"email": "mail",
"displayName": "cn",
"department": "departmentNumber",
"groups": "memberOf",
},
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("map LDAP attributes to identity", func(t *testing.T) {
ldapAttrs := map[string][]string{
"mail": {"user@example.com"},
"cn": {"John Doe"},
"departmentNumber": {"IT"},
"memberOf": {
"CN=Users,OU=Groups,DC=example,DC=com",
"CN=Developers,OU=Groups,DC=example,DC=com",
},
}
identity := provider.mapLDAPAttributes("testuser", ldapAttrs)
assert.Equal(t, "testuser", identity.UserID)
assert.Equal(t, "user@example.com", identity.Email)
assert.Equal(t, "John Doe", identity.DisplayName)
assert.Equal(t, "test-ldap", identity.Provider)
// Check groups
assert.Contains(t, identity.Groups, "CN=Users,OU=Groups,DC=example,DC=com")
assert.Contains(t, identity.Groups, "CN=Developers,OU=Groups,DC=example,DC=com")
// Check attributes
assert.Equal(t, "IT", identity.Attributes["department"])
})
}
// TestLDAPGroupFiltering tests LDAP group filtering and role mapping
func TestLDAPGroupFiltering(t *testing.T) {
provider := NewLDAPProvider("test-ldap")
config := &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
RoleMapping: &providers.RoleMapping{
Rules: []providers.MappingRule{
{
Claim: "groups",
Value: "*Admins*",
Role: "arn:seaweed:iam::role/AdminRole",
},
{
Claim: "groups",
Value: "*Users*",
Role: "arn:seaweed:iam::role/UserRole",
},
},
DefaultRole: "arn:seaweed:iam::role/GuestRole",
},
}
err := provider.Initialize(config)
require.NoError(t, err)
tests := []struct {
name string
groups []string
expectedRole string
expectedClaims map[string]interface{}
}{
{
name: "admin user",
groups: []string{"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"},
expectedRole: "arn:seaweed:iam::role/AdminRole",
},
{
name: "regular user",
groups: []string{"CN=Users,OU=Groups,DC=example,DC=com"},
expectedRole: "arn:seaweed:iam::role/UserRole",
},
{
name: "guest user",
groups: []string{"CN=Guests,OU=Groups,DC=example,DC=com"},
expectedRole: "arn:seaweed:iam::role/GuestRole",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
identity := &providers.ExternalIdentity{
UserID: "testuser",
Groups: tt.groups,
Provider: "test-ldap",
}
role := provider.mapUserToRole(identity)
assert.Equal(t, tt.expectedRole, role)
})
}
}
// TestLDAPConnectionPool tests LDAP connection pooling
func TestLDAPConnectionPool(t *testing.T) {
if testing.Short() {
t.Skip("Skipping LDAP connection pool test in short mode")
}
provider := NewLDAPProvider("test-ldap")
config := &LDAPConfig{
Server: "ldap://localhost:389",
BaseDN: "DC=example,DC=com",
BindDN: "CN=admin,DC=example,DC=com",
BindPass: "password",
MaxConnections: 5,
ConnTimeout: 30,
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("connection pool management", func(t *testing.T) {
// Test that multiple concurrent requests work
// This would require actual LDAP server for full testing
pool := provider.getConnectionPool()
assert.NotNil(t, pool)
// Test connection acquisition and release
conn, err := provider.getConnection()
if err != nil {
t.Skip("LDAP server not available")
}
assert.NotNil(t, conn)
provider.releaseConnection(conn)
})
}
// MockLDAPServer for unit testing (without external dependencies)
type MockLDAPServer struct {
users map[string]map[string][]string
}
func NewMockLDAPServer() *MockLDAPServer {
return &MockLDAPServer{
users: map[string]map[string][]string{
"testuser": {
"mail": {"testuser@example.com"},
"cn": {"Test User"},
"department": {"Engineering"},
"memberOf": {"CN=Users,OU=Groups,DC=example,DC=com"},
},
"admin": {
"mail": {"admin@example.com"},
"cn": {"Administrator"},
"department": {"IT"},
"memberOf": {"CN=Admins,OU=Groups,DC=example,DC=com", "CN=Users,OU=Groups,DC=example,DC=com"},
},
},
}
}
func (m *MockLDAPServer) Authenticate(username, password string) bool {
_, exists := m.users[username]
return exists && password == "password" // Mock authentication
}
func (m *MockLDAPServer) GetUserAttributes(username string) (map[string][]string, error) {
if attrs, exists := m.users[username]; exists {
return attrs, nil
}
return nil, fmt.Errorf("user not found: %s", username)
}

124
weed/iam/oidc/oidc_provider.go

@ -0,0 +1,124 @@
package oidc
import (
"context"
"fmt"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// OIDCProvider implements OpenID Connect authentication
type OIDCProvider struct {
name string
config *OIDCConfig
initialized bool
}
// OIDCConfig holds OIDC provider configuration
type OIDCConfig struct {
// Issuer is the OIDC issuer URL
Issuer string `json:"issuer"`
// ClientID is the OAuth2 client ID
ClientID string `json:"clientId"`
// ClientSecret is the OAuth2 client secret (optional for public clients)
ClientSecret string `json:"clientSecret,omitempty"`
// JWKSUri is the JSON Web Key Set URI
JWKSUri string `json:"jwksUri,omitempty"`
// UserInfoUri is the UserInfo endpoint URI
UserInfoUri string `json:"userInfoUri,omitempty"`
// Scopes are the OAuth2 scopes to request
Scopes []string `json:"scopes,omitempty"`
// RoleMapping defines how to map OIDC claims to roles
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"`
// ClaimsMapping defines how to map OIDC claims to identity attributes
ClaimsMapping map[string]string `json:"claimsMapping,omitempty"`
}
// NewOIDCProvider creates a new OIDC provider
func NewOIDCProvider(name string) *OIDCProvider {
return &OIDCProvider{
name: name,
}
}
// Name returns the provider name
func (p *OIDCProvider) Name() string {
return p.name
}
// Initialize initializes the OIDC provider with configuration
func (p *OIDCProvider) Initialize(config interface{}) error {
oidcConfig, ok := config.(*OIDCConfig)
if !ok {
return fmt.Errorf("invalid config type for OIDC provider")
}
if err := p.validateConfig(oidcConfig); err != nil {
return fmt.Errorf("invalid OIDC configuration: %w", err)
}
p.config = oidcConfig
p.initialized = true
// TODO: Initialize OIDC client, fetch JWKS, etc.
return fmt.Errorf("not implemented yet")
}
// validateConfig validates the OIDC configuration
func (p *OIDCProvider) validateConfig(config *OIDCConfig) error {
if config.Issuer == "" {
return fmt.Errorf("issuer is required")
}
if config.ClientID == "" {
return fmt.Errorf("client ID is required")
}
// Basic URL validation for issuer
if config.Issuer != "" && config.Issuer != "https://accounts.google.com" && config.Issuer[0:4] != "http" {
return fmt.Errorf("invalid issuer URL format")
}
return nil
}
// Authenticate authenticates a user with an OIDC token
func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
// TODO: Validate JWT token, extract claims, map to identity
return nil, fmt.Errorf("not implemented yet")
}
// GetUserInfo retrieves user information from the UserInfo endpoint
func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if userID == "" {
return nil, fmt.Errorf("user ID cannot be empty")
}
// TODO: Call UserInfo endpoint
return nil, fmt.Errorf("not implemented yet")
}
// ValidateToken validates an OIDC JWT token
func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
// TODO: Validate JWT signature, claims, expiration
return nil, fmt.Errorf("not implemented yet")
}

318
weed/iam/oidc/oidc_provider_test.go

@ -0,0 +1,318 @@
package oidc
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOIDCProviderInitialization tests OIDC provider initialization
func TestOIDCProviderInitialization(t *testing.T) {
tests := []struct {
name string
config *OIDCConfig
wantErr bool
}{
{
name: "valid config",
config: &OIDCConfig{
Issuer: "https://accounts.google.com",
ClientID: "test-client-id",
JWKSUri: "https://www.googleapis.com/oauth2/v3/certs",
},
wantErr: false,
},
{
name: "missing issuer",
config: &OIDCConfig{
ClientID: "test-client-id",
},
wantErr: true,
},
{
name: "missing client id",
config: &OIDCConfig{
Issuer: "https://accounts.google.com",
},
wantErr: true,
},
{
name: "invalid issuer url",
config: &OIDCConfig{
Issuer: "not-a-url",
ClientID: "test-client-id",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := NewOIDCProvider("test-provider")
err := provider.Initialize(tt.config)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, "test-provider", provider.Name())
}
})
}
}
// TestOIDCProviderJWTValidation tests JWT token validation
func TestOIDCProviderJWTValidation(t *testing.T) {
// Set up test server with JWKS endpoint
privateKey, publicKey := generateTestKeys(t)
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-id",
"use": "sig",
"alg": "RS256",
"n": encodePublicKey(t, publicKey),
"e": "AQAB",
},
},
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid_configuration" {
config := map[string]interface{}{
"issuer": "http://" + r.Host,
"jwks_uri": "http://" + r.Host + "/jwks",
}
json.NewEncoder(w).Encode(config)
} else if r.URL.Path == "/jwks" {
json.NewEncoder(w).Encode(jwks)
}
}))
defer server.Close()
provider := NewOIDCProvider("test-oidc")
config := &OIDCConfig{
Issuer: server.URL,
ClientID: "test-client",
JWKSUri: server.URL + "/jwks",
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("valid token", func(t *testing.T) {
// Create valid JWT token
token := createTestJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user123",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"email": "user@example.com",
"name": "Test User",
})
claims, err := provider.ValidateToken(context.Background(), token)
assert.NoError(t, err)
assert.Equal(t, "user123", claims.Subject)
assert.Equal(t, server.URL, claims.Issuer)
email, exists := claims.GetClaimString("email")
assert.True(t, exists)
assert.Equal(t, "user@example.com", email)
})
t.Run("expired token", func(t *testing.T) {
// Create expired JWT token
token := createTestJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user123",
"exp": time.Now().Add(-time.Hour).Unix(), // Expired
"iat": time.Now().Add(-time.Hour * 2).Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
assert.Error(t, err)
})
t.Run("invalid signature", func(t *testing.T) {
// Create token with wrong key
wrongKey, _ := generateTestKeys(t)
token := createTestJWT(t, wrongKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user123",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
_, err := provider.ValidateToken(context.Background(), token)
assert.Error(t, err)
})
}
// TestOIDCProviderAuthentication tests authentication flow
func TestOIDCProviderAuthentication(t *testing.T) {
// Set up test OIDC provider
privateKey, publicKey := generateTestKeys(t)
server := setupOIDCTestServer(t, publicKey)
defer server.Close()
provider := NewOIDCProvider("test-oidc")
config := &OIDCConfig{
Issuer: server.URL,
ClientID: "test-client",
JWKSUri: server.URL + "/jwks",
RoleMapping: &providers.RoleMapping{
Rules: []providers.MappingRule{
{
Claim: "email",
Value: "*@example.com",
Role: "arn:seaweed:iam::role/UserRole",
},
{
Claim: "groups",
Value: "admins",
Role: "arn:seaweed:iam::role/AdminRole",
},
},
DefaultRole: "arn:seaweed:iam::role/GuestRole",
},
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("successful authentication", func(t *testing.T) {
token := createTestJWT(t, privateKey, jwt.MapClaims{
"iss": server.URL,
"aud": "test-client",
"sub": "user123",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"email": "user@example.com",
"name": "Test User",
"groups": []string{"users", "developers"},
})
identity, err := provider.Authenticate(context.Background(), token)
assert.NoError(t, err)
assert.Equal(t, "user123", identity.UserID)
assert.Equal(t, "user@example.com", identity.Email)
assert.Equal(t, "Test User", identity.DisplayName)
assert.Equal(t, "test-oidc", identity.Provider)
assert.Contains(t, identity.Groups, "users")
assert.Contains(t, identity.Groups, "developers")
})
t.Run("authentication with invalid token", func(t *testing.T) {
_, err := provider.Authenticate(context.Background(), "invalid-token")
assert.Error(t, err)
})
}
// TestOIDCProviderUserInfo tests user info retrieval
func TestOIDCProviderUserInfo(t *testing.T) {
// Set up test server with UserInfo endpoint
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/userinfo" {
userInfo := map[string]interface{}{
"sub": r.URL.Query().Get("user_id"),
"email": "user@example.com",
"name": "Test User",
}
json.NewEncoder(w).Encode(userInfo)
}
}))
defer server.Close()
provider := NewOIDCProvider("test-oidc")
config := &OIDCConfig{
Issuer: server.URL,
ClientID: "test-client",
UserInfoUri: server.URL + "/userinfo",
}
err := provider.Initialize(config)
require.NoError(t, err)
t.Run("get user info", func(t *testing.T) {
identity, err := provider.GetUserInfo(context.Background(), "user123")
assert.NoError(t, err)
assert.Equal(t, "user123", identity.UserID)
assert.Equal(t, "user@example.com", identity.Email)
assert.Equal(t, "Test User", identity.DisplayName)
})
t.Run("get user info with empty id", func(t *testing.T) {
_, err := provider.GetUserInfo(context.Background(), "")
assert.Error(t, err)
})
}
// Helper functions for testing
func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
return privateKey, &privateKey.PublicKey
}
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
// This is a simplified version - real implementation would properly encode the public key
return "test-public-key-n-value"
}
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-id",
"use": "sig",
"alg": "RS256",
"n": encodePublicKey(t, publicKey),
"e": "AQAB",
},
},
}
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid_configuration":
config := map[string]interface{}{
"issuer": "http://" + r.Host,
"jwks_uri": "http://" + r.Host + "/jwks",
}
json.NewEncoder(w).Encode(config)
case "/jwks":
json.NewEncoder(w).Encode(jwks)
default:
http.NotFound(w, r)
}
}))
}

224
weed/iam/providers/provider.go

@ -0,0 +1,224 @@
package providers
import (
"context"
"fmt"
"net/mail"
"strings"
"time"
)
// IdentityProvider defines the interface for external identity providers
type IdentityProvider interface {
// Name returns the unique name of the provider
Name() string
// Initialize initializes the provider with configuration
Initialize(config interface{}) error
// Authenticate authenticates a user with a token and returns external identity
Authenticate(ctx context.Context, token string) (*ExternalIdentity, error)
// GetUserInfo retrieves user information by user ID
GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error)
// ValidateToken validates a token and returns claims
ValidateToken(ctx context.Context, token string) (*TokenClaims, error)
}
// ExternalIdentity represents an identity from an external provider
type ExternalIdentity struct {
// UserID is the unique identifier from the external provider
UserID string `json:"userId"`
// Email is the user's email address
Email string `json:"email"`
// DisplayName is the user's display name
DisplayName string `json:"displayName"`
// Groups are the groups the user belongs to
Groups []string `json:"groups,omitempty"`
// Attributes are additional user attributes
Attributes map[string]string `json:"attributes,omitempty"`
// Provider is the name of the identity provider
Provider string `json:"provider"`
}
// Validate validates the external identity structure
func (e *ExternalIdentity) Validate() error {
if e.UserID == "" {
return fmt.Errorf("user ID is required")
}
if e.Provider == "" {
return fmt.Errorf("provider is required")
}
if e.Email != "" {
if _, err := mail.ParseAddress(e.Email); err != nil {
return fmt.Errorf("invalid email format: %w", err)
}
}
return nil
}
// TokenClaims represents claims from a validated token
type TokenClaims struct {
// Subject (sub) - user identifier
Subject string `json:"sub"`
// Issuer (iss) - token issuer
Issuer string `json:"iss"`
// Audience (aud) - intended audience
Audience string `json:"aud"`
// ExpiresAt (exp) - expiration time
ExpiresAt time.Time `json:"exp"`
// IssuedAt (iat) - issued at time
IssuedAt time.Time `json:"iat"`
// NotBefore (nbf) - not valid before time
NotBefore time.Time `json:"nbf,omitempty"`
// Claims are additional claims from the token
Claims map[string]interface{} `json:"claims,omitempty"`
}
// IsValid checks if the token claims are valid (not expired, etc.)
func (c *TokenClaims) IsValid() bool {
now := time.Now()
// Check expiration
if !c.ExpiresAt.IsZero() && now.After(c.ExpiresAt) {
return false
}
// Check not before
if !c.NotBefore.IsZero() && now.Before(c.NotBefore) {
return false
}
// Check issued at (shouldn't be in the future)
if !c.IssuedAt.IsZero() && now.Before(c.IssuedAt) {
return false
}
return true
}
// GetClaimString returns a string claim value
func (c *TokenClaims) GetClaimString(key string) (string, bool) {
if value, exists := c.Claims[key]; exists {
if str, ok := value.(string); ok {
return str, true
}
}
return "", false
}
// GetClaimStringSlice returns a string slice claim value
func (c *TokenClaims) GetClaimStringSlice(key string) ([]string, bool) {
if value, exists := c.Claims[key]; exists {
switch v := value.(type) {
case []string:
return v, true
case []interface{}:
var result []string
for _, item := range v {
if str, ok := item.(string); ok {
result = append(result, str)
}
}
return result, len(result) > 0
case string:
// Single string can be treated as slice
return []string{v}, true
}
}
return nil, false
}
// ProviderConfig represents configuration for identity providers
type ProviderConfig struct {
// Type of provider (oidc, ldap, saml)
Type string `json:"type"`
// Name of the provider instance
Name string `json:"name"`
// Enabled indicates if the provider is active
Enabled bool `json:"enabled"`
// Config is provider-specific configuration
Config map[string]interface{} `json:"config"`
// RoleMapping defines how to map external identities to roles
RoleMapping *RoleMapping `json:"roleMapping,omitempty"`
}
// RoleMapping defines rules for mapping external identities to roles
type RoleMapping struct {
// Rules are the mapping rules
Rules []MappingRule `json:"rules"`
// DefaultRole is assigned if no rules match
DefaultRole string `json:"defaultRole,omitempty"`
}
// MappingRule defines a single mapping rule
type MappingRule struct {
// Claim is the claim key to check
Claim string `json:"claim"`
// Value is the expected claim value (supports wildcards)
Value string `json:"value"`
// Role is the role ARN to assign
Role string `json:"role"`
// Condition is additional condition logic (optional)
Condition string `json:"condition,omitempty"`
}
// Matches checks if a rule matches the given claims
func (r *MappingRule) Matches(claims *TokenClaims) bool {
if r.Claim == "" || r.Value == "" {
return false
}
claimValue, exists := claims.GetClaimString(r.Claim)
if !exists {
// Try as string slice
if claimSlice, sliceExists := claims.GetClaimStringSlice(r.Claim); sliceExists {
for _, val := range claimSlice {
if r.matchValue(val) {
return true
}
}
}
return false
}
return r.matchValue(claimValue)
}
// matchValue checks if a value matches the rule value (with wildcard support)
func (r *MappingRule) matchValue(value string) bool {
// Simple wildcard matching
if strings.Contains(r.Value, "*") {
// Convert wildcard to regex-like matching
pattern := strings.ReplaceAll(r.Value, "*", "")
if pattern == "" {
return true // "*" matches everything
}
return strings.Contains(value, pattern)
}
return value == r.Value
}

246
weed/iam/providers/provider_test.go

@ -0,0 +1,246 @@
package providers
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIdentityProviderInterface tests the core identity provider interface
func TestIdentityProviderInterface(t *testing.T) {
tests := []struct {
name string
provider IdentityProvider
wantErr bool
}{
// We'll add test cases as we implement providers
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test provider name
name := tt.provider.Name()
assert.NotEmpty(t, name, "Provider name should not be empty")
// Test initialization
err := tt.provider.Initialize(nil)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
// Test authentication with invalid token
ctx := context.Background()
_, err = tt.provider.Authenticate(ctx, "invalid-token")
assert.Error(t, err, "Should fail with invalid token")
})
}
}
// TestExternalIdentityValidation tests external identity structure validation
func TestExternalIdentityValidation(t *testing.T) {
tests := []struct {
name string
identity *ExternalIdentity
wantErr bool
}{
{
name: "valid identity",
identity: &ExternalIdentity{
UserID: "user123",
Email: "user@example.com",
DisplayName: "Test User",
Groups: []string{"group1", "group2"},
Attributes: map[string]string{"dept": "engineering"},
Provider: "test-provider",
},
wantErr: false,
},
{
name: "missing user id",
identity: &ExternalIdentity{
Email: "user@example.com",
Provider: "test-provider",
},
wantErr: true,
},
{
name: "missing provider",
identity: &ExternalIdentity{
UserID: "user123",
Email: "user@example.com",
},
wantErr: true,
},
{
name: "invalid email",
identity: &ExternalIdentity{
UserID: "user123",
Email: "invalid-email",
Provider: "test-provider",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.identity.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestTokenClaimsValidation tests token claims structure
func TestTokenClaimsValidation(t *testing.T) {
tests := []struct {
name string
claims *TokenClaims
valid bool
}{
{
name: "valid claims",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now().Add(-time.Minute),
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: true,
},
{
name: "expired token",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(-time.Hour), // Expired
IssuedAt: time.Now().Add(-time.Hour * 2),
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: false,
},
{
name: "future issued token",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now().Add(time.Hour), // Future
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
valid := tt.claims.IsValid()
assert.Equal(t, tt.valid, valid)
})
}
}
// TestProviderRegistry tests provider registration and discovery
func TestProviderRegistry(t *testing.T) {
// Clear registry for test
registry := NewProviderRegistry()
t.Run("register provider", func(t *testing.T) {
mockProvider := &MockProvider{name: "test-provider"}
err := registry.RegisterProvider(mockProvider)
assert.NoError(t, err)
// Test duplicate registration
err = registry.RegisterProvider(mockProvider)
assert.Error(t, err, "Should not allow duplicate registration")
})
t.Run("get provider", func(t *testing.T) {
provider, exists := registry.GetProvider("test-provider")
assert.True(t, exists)
assert.Equal(t, "test-provider", provider.Name())
// Test non-existent provider
_, exists = registry.GetProvider("non-existent")
assert.False(t, exists)
})
t.Run("list providers", func(t *testing.T) {
providers := registry.ListProviders()
assert.Len(t, providers, 1)
assert.Equal(t, "test-provider", providers[0])
})
}
// MockProvider for testing
type MockProvider struct {
name string
initialized bool
shouldError bool
}
func (m *MockProvider) Name() string {
return m.name
}
func (m *MockProvider) Initialize(config interface{}) error {
if m.shouldError {
return assert.AnError
}
m.initialized = true
return nil
}
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
if !m.initialized {
return nil, assert.AnError
}
if token == "invalid-token" {
return nil, assert.AnError
}
return &ExternalIdentity{
UserID: "test-user",
Email: "test@example.com",
DisplayName: "Test User",
Provider: m.name,
}, nil
}
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
if !m.initialized || userID == "" {
return nil, assert.AnError
}
return &ExternalIdentity{
UserID: userID,
Email: userID + "@example.com",
DisplayName: "User " + userID,
Provider: m.name,
}, nil
}
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
if !m.initialized || token == "invalid-token" {
return nil, assert.AnError
}
return &TokenClaims{
Subject: "test-user",
Issuer: "test-issuer",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now(),
Claims: map[string]interface{}{"email": "test@example.com"},
}, nil
}

109
weed/iam/providers/registry.go

@ -0,0 +1,109 @@
package providers
import (
"fmt"
"sync"
)
// ProviderRegistry manages registered identity providers
type ProviderRegistry struct {
mu sync.RWMutex
providers map[string]IdentityProvider
}
// NewProviderRegistry creates a new provider registry
func NewProviderRegistry() *ProviderRegistry {
return &ProviderRegistry{
providers: make(map[string]IdentityProvider),
}
}
// RegisterProvider registers a new identity provider
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error {
if provider == nil {
return fmt.Errorf("provider cannot be nil")
}
name := provider.Name()
if name == "" {
return fmt.Errorf("provider name cannot be empty")
}
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.providers[name]; exists {
return fmt.Errorf("provider %s is already registered", name)
}
r.providers[name] = provider
return nil
}
// GetProvider retrieves a provider by name
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
provider, exists := r.providers[name]
return provider, exists
}
// ListProviders returns all registered provider names
func (r *ProviderRegistry) ListProviders() []string {
r.mu.RLock()
defer r.mu.RUnlock()
var names []string
for name := range r.providers {
names = append(names, name)
}
return names
}
// UnregisterProvider removes a provider from the registry
func (r *ProviderRegistry) UnregisterProvider(name string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.providers[name]; !exists {
return fmt.Errorf("provider %s is not registered", name)
}
delete(r.providers, name)
return nil
}
// Clear removes all providers from the registry
func (r *ProviderRegistry) Clear() {
r.mu.Lock()
defer r.mu.Unlock()
r.providers = make(map[string]IdentityProvider)
}
// GetProviderCount returns the number of registered providers
func (r *ProviderRegistry) GetProviderCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.providers)
}
// Default global registry
var defaultRegistry = NewProviderRegistry()
// RegisterProvider registers a provider in the default registry
func RegisterProvider(provider IdentityProvider) error {
return defaultRegistry.RegisterProvider(provider)
}
// GetProvider retrieves a provider from the default registry
func GetProvider(name string) (IdentityProvider, bool) {
return defaultRegistry.GetProvider(name)
}
// ListProviders returns all provider names from the default registry
func ListProviders() []string {
return defaultRegistry.ListProviders()
}
Loading…
Cancel
Save