You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

460 lines
13 KiB

package oidc
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOIDCProviderInitialization checks that Initialize accepts a complete
// configuration and rejects configurations that lack an issuer or client ID,
// or whose issuer is not a URL. On success the provider keeps the name it was
// constructed with.
func TestOIDCProviderInitialization(t *testing.T) {
	cases := []struct {
		name    string
		config  *OIDCConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: &OIDCConfig{
				Issuer:   "https://accounts.google.com",
				ClientID: "test-client-id",
				JWKSUri:  "https://www.googleapis.com/oauth2/v3/certs",
			},
			wantErr: false,
		},
		{
			name:    "missing issuer",
			config:  &OIDCConfig{ClientID: "test-client-id"},
			wantErr: true,
		},
		{
			name:    "missing client id",
			config:  &OIDCConfig{Issuer: "https://accounts.google.com"},
			wantErr: true,
		},
		{
			name: "invalid issuer url",
			config: &OIDCConfig{
				Issuer:   "not-a-url",
				ClientID: "test-client-id",
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			p := NewOIDCProvider("test-provider")
			initErr := p.Initialize(tc.config)

			// Invalid configurations only need to be rejected.
			if tc.wantErr {
				assert.Error(t, initErr)
				return
			}

			assert.NoError(t, initErr)
			assert.Equal(t, "test-provider", p.Name())
		})
	}
}
// TestOIDCProviderJWTValidation tests JWT validation through ValidateToken.
// Tokens signed with the key published in the mock server's JWKS must
// validate, with either a string or an array "aud" claim. Expired tokens and
// tokens signed with a different key must be rejected.
func TestOIDCProviderJWTValidation(t *testing.T) {
	// Reuse the shared mock OIDC server so the JWKS definition lives in one place.
	privateKey, publicKey := generateTestKeys(t)
	server := setupOIDCTestServer(t, publicKey)
	defer server.Close()

	provider := NewOIDCProvider("test-oidc")
	config := &OIDCConfig{
		Issuer:   server.URL,
		ClientID: "test-client",
		JWKSUri:  server.URL + "/jwks",
	}
	err := provider.Initialize(config)
	require.NoError(t, err)

	t.Run("valid token", func(t *testing.T) {
		// Create valid JWT token
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss":   server.URL,
			"aud":   "test-client",
			"sub":   "user123",
			"exp":   time.Now().Add(time.Hour).Unix(),
			"iat":   time.Now().Unix(),
			"email": "user@example.com",
			"name":  "Test User",
		})

		claims, err := provider.ValidateToken(context.Background(), token)
		require.NoError(t, err)
		require.NotNil(t, claims)
		assert.Equal(t, "user123", claims.Subject)
		assert.Equal(t, server.URL, claims.Issuer)

		email, exists := claims.GetClaimString("email")
		assert.True(t, exists)
		assert.Equal(t, "user@example.com", email)
	})

	t.Run("valid token with array audience", func(t *testing.T) {
		// RFC 7519 allows "aud" to be an array; the client ID only has to be one member.
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss":   server.URL,
			"aud":   []string{"test-client", "another-client"},
			"sub":   "user456",
			"exp":   time.Now().Add(time.Hour).Unix(),
			"iat":   time.Now().Unix(),
			"email": "user2@example.com",
			"name":  "Test User 2",
		})

		claims, err := provider.ValidateToken(context.Background(), token)
		require.NoError(t, err)
		require.NotNil(t, claims)
		assert.Equal(t, "user456", claims.Subject)
		assert.Equal(t, server.URL, claims.Issuer)

		email, exists := claims.GetClaimString("email")
		assert.True(t, exists)
		assert.Equal(t, "user2@example.com", email)
	})

	t.Run("expired token", func(t *testing.T) {
		// Create expired JWT token
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss": server.URL,
			"aud": "test-client",
			"sub": "user123",
			"exp": time.Now().Add(-time.Hour).Unix(), // Expired
			"iat": time.Now().Add(-time.Hour * 2).Unix(),
		})

		_, err := provider.ValidateToken(context.Background(), token)
		// require (not assert): if validation wrongly succeeds, err is nil and
		// err.Error() below would panic instead of reporting a clean failure.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "expired")
	})

	t.Run("invalid signature", func(t *testing.T) {
		// Sign with a key that is not published in the JWKS.
		wrongKey, _ := generateTestKeys(t)
		token := createTestJWT(t, wrongKey, jwt.MapClaims{
			"iss": server.URL,
			"aud": "test-client",
			"sub": "user123",
			"exp": time.Now().Add(time.Hour).Unix(),
			"iat": time.Now().Unix(),
		})

		_, err := provider.ValidateToken(context.Background(), token)
		assert.Error(t, err)
	})
}
// TestOIDCProviderAuthentication exercises Authenticate end to end against the
// mock OIDC server. The provider is configured with a role mapping. A token
// signed with the published key must yield an identity populated from its
// email, name, sub and groups claims. A malformed token must be rejected.
func TestOIDCProviderAuthentication(t *testing.T) {
	privateKey, publicKey := generateTestKeys(t)
	server := setupOIDCTestServer(t, publicKey)
	defer server.Close()

	roleMapping := &providers.RoleMapping{
		Rules: []providers.MappingRule{
			{Claim: "email", Value: "*@example.com", Role: "arn:seaweed:iam::role/UserRole"},
			{Claim: "groups", Value: "admins", Role: "arn:seaweed:iam::role/AdminRole"},
		},
		DefaultRole: "arn:seaweed:iam::role/GuestRole",
	}

	provider := NewOIDCProvider("test-oidc")
	require.NoError(t, provider.Initialize(&OIDCConfig{
		Issuer:      server.URL,
		ClientID:    "test-client",
		JWKSUri:     server.URL + "/jwks",
		RoleMapping: roleMapping,
	}))

	ctx := context.Background()

	t.Run("successful authentication", func(t *testing.T) {
		tokenClaims := jwt.MapClaims{
			"iss":    server.URL,
			"aud":    "test-client",
			"sub":    "user123",
			"exp":    time.Now().Add(time.Hour).Unix(),
			"iat":    time.Now().Unix(),
			"email":  "user@example.com",
			"name":   "Test User",
			"groups": []string{"users", "developers"},
		}
		signed := createTestJWT(t, privateKey, tokenClaims)

		ident, authErr := provider.Authenticate(ctx, signed)
		require.NoError(t, authErr)
		require.NotNil(t, ident)

		assert.Equal(t, "user123", ident.UserID)
		assert.Equal(t, "user@example.com", ident.Email)
		assert.Equal(t, "Test User", ident.DisplayName)
		assert.Equal(t, "test-oidc", ident.Provider)
		assert.Contains(t, ident.Groups, "users")
		assert.Contains(t, ident.Groups, "developers")
	})

	t.Run("authentication with invalid token", func(t *testing.T) {
		_, authErr := provider.Authenticate(ctx, "invalid-token")
		assert.Error(t, authErr)
	})
}
// TestOIDCProviderUserInfo tests UserInfo retrieval against a mock UserInfo
// endpoint. The mock requires a Bearer token and answers 401 for the token
// "invalid-token". Tokens containing "admin" get the group "admins"; all other
// tokens get the groups "users" and "developers".
func TestOIDCProviderUserInfo(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only the UserInfo endpoint exists; make any other request visibly fail
		// (404) instead of silently answering an empty 200.
		if r.URL.Path != "/userinfo" {
			http.NotFound(w, r)
			return
		}

		authHeader := r.Header.Get("Authorization")
		if !strings.HasPrefix(authHeader, "Bearer ") {
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte(`{"error": "unauthorized"}`))
			return
		}

		accessToken := strings.TrimPrefix(authHeader, "Bearer ")
		// Return 401 for explicitly invalid tokens
		if accessToken == "invalid-token" {
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte(`{"error": "invalid_token"}`))
			return
		}

		userInfo := map[string]interface{}{
			"sub":    "user123",
			"email":  "user@example.com",
			"name":   "Test User",
			"groups": []string{"users", "developers"},
		}
		// Customize response based on token
		if strings.Contains(accessToken, "admin") {
			userInfo["groups"] = []string{"admins"}
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(userInfo); err != nil {
			t.Errorf("mock UserInfo endpoint: encoding response: %v", err)
		}
	}))
	defer server.Close()

	provider := NewOIDCProvider("test-oidc")
	config := &OIDCConfig{
		Issuer:      server.URL,
		ClientID:    "test-client",
		UserInfoUri: server.URL + "/userinfo",
	}
	err := provider.Initialize(config)
	require.NoError(t, err)

	t.Run("get user info with access token", func(t *testing.T) {
		// Test using access token (real UserInfo endpoint call)
		identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)
		assert.Equal(t, "user123", identity.UserID)
		assert.Equal(t, "user@example.com", identity.Email)
		assert.Equal(t, "Test User", identity.DisplayName)
		assert.Contains(t, identity.Groups, "users")
		assert.Contains(t, identity.Groups, "developers")
		assert.Equal(t, "test-oidc", identity.Provider)
	})

	t.Run("get admin user info", func(t *testing.T) {
		identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)
		assert.Equal(t, "user123", identity.UserID)
		assert.Contains(t, identity.Groups, "admins")
	})

	t.Run("get user info without token", func(t *testing.T) {
		_, err := provider.GetUserInfoWithToken(context.Background(), "")
		// require (not assert): a nil err would make err.Error() panic.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "access token cannot be empty")
	})

	t.Run("get user info with invalid token", func(t *testing.T) {
		// The mock endpoint answers 401 for this token.
		_, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
		// require (not assert): a nil err would make err.Error() panic.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
	})

	t.Run("get user info with custom claims mapping", func(t *testing.T) {
		// Create provider with custom claims mapping
		customProvider := NewOIDCProvider("test-custom-oidc")
		customConfig := &OIDCConfig{
			Issuer:      server.URL,
			ClientID:    "test-client",
			UserInfoUri: server.URL + "/userinfo",
			ClaimsMapping: map[string]string{
				"customEmail": "email",
				"customName":  "name",
			},
		}
		err := customProvider.Initialize(customConfig)
		require.NoError(t, err)

		identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)
		// Standard claims should still work
		assert.Equal(t, "user123", identity.UserID)
		assert.Equal(t, "user@example.com", identity.Email)
		assert.Equal(t, "Test User", identity.DisplayName)
	})

	t.Run("get user info with empty id", func(t *testing.T) {
		_, err := provider.GetUserInfo(context.Background(), "")
		assert.Error(t, err)
	})
}
// Helper functions for testing
// generateTestKeys returns a freshly generated 2048-bit RSA key pair. The
// returned public key is the private key's embedded PublicKey. The test fails
// immediately if key generation errors.
func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
	// Attribute a failure to the calling test line, not to this helper.
	t.Helper()
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)
	return privateKey, &privateKey.PublicKey
}
// createTestJWT signs claims with privateKey using RS256 and returns the
// compact serialized token. The header's "kid" is set to "test-key-id", the
// key ID under which the test JWKS documents in this file publish the public
// key. The test fails immediately if signing errors.
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
	// Attribute a failure to the calling test line, not to this helper.
	t.Helper()
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = "test-key-id"
	tokenString, err := token.SignedString(privateKey)
	require.NoError(t, err)
	return tokenString
}
// encodePublicKey returns the RSA modulus of publicKey as unpadded base64url.
// It uses the modulus's minimal big-endian bytes, which is the form a JWK's
// "n" member uses. The t parameter is unused.
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
	modulus := publicKey.N.Bytes()
	return base64.RawURLEncoding.EncodeToString(modulus)
}
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-id",
"use": "sig",
"alg": "RS256",
"n": encodePublicKey(t, publicKey),
"e": "AQAB",
},
},
}
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid_configuration":
config := map[string]interface{}{
"issuer": "http://" + r.Host,
"jwks_uri": "http://" + r.Host + "/jwks",
"userinfo_endpoint": "http://" + r.Host + "/userinfo",
}
json.NewEncoder(w).Encode(config)
case "/jwks":
json.NewEncoder(w).Encode(jwks)
case "/userinfo":
// Mock UserInfo endpoint
authHeader := r.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error": "unauthorized"}`))
return
}
accessToken := strings.TrimPrefix(authHeader, "Bearer ")
// Return 401 for explicitly invalid tokens
if accessToken == "invalid-token" {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error": "invalid_token"}`))
return
}
// Mock user info response based on access token
userInfo := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
"name": "Test User",
"groups": []string{"users", "developers"},
}
// Customize response based on token
if strings.Contains(accessToken, "admin") {
userInfo["groups"] = []string{"admins"}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(userInfo)
default:
http.NotFound(w, r)
}
}))
}