You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
460 lines
13 KiB
460 lines
13 KiB
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestOIDCProviderInitialization checks that Initialize accepts a complete
// configuration and rejects configurations that omit the issuer, omit the
// client ID, or use an issuer that is not a URL. On success the provider must
// keep the name it was constructed with.
func TestOIDCProviderInitialization(t *testing.T) {
	testCases := []struct {
		name    string
		config  *OIDCConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: &OIDCConfig{
				Issuer:   "https://accounts.google.com",
				ClientID: "test-client-id",
				JWKSUri:  "https://www.googleapis.com/oauth2/v3/certs",
			},
		},
		{
			name:    "missing issuer",
			config:  &OIDCConfig{ClientID: "test-client-id"},
			wantErr: true,
		},
		{
			name:    "missing client id",
			config:  &OIDCConfig{Issuer: "https://accounts.google.com"},
			wantErr: true,
		},
		{
			name: "invalid issuer url",
			config: &OIDCConfig{
				Issuer:   "not-a-url",
				ClientID: "test-client-id",
			},
			wantErr: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			p := NewOIDCProvider("test-provider")
			initErr := p.Initialize(tc.config)

			if tc.wantErr {
				assert.Error(t, initErr)
				return
			}

			assert.NoError(t, initErr)
			assert.Equal(t, "test-provider", p.Name())
		})
	}
}
|
|
|
|
// TestOIDCProviderJWTValidation exercises ValidateToken against a local JWKS
// endpoint. Tokens signed by the published key with the expected issuer and
// audience (given as a string or as an array) must validate. Expired tokens
// and tokens signed by a different key must be rejected.
func TestOIDCProviderJWTValidation(t *testing.T) {
	// Serve the JWKS for publicKey through the shared helper instead of an
	// inline copy of the same key document and handler.
	privateKey, publicKey := generateTestKeys(t)

	server := setupOIDCTestServer(t, publicKey)
	defer server.Close()

	provider := NewOIDCProvider("test-oidc")
	config := &OIDCConfig{
		Issuer:   server.URL,
		ClientID: "test-client",
		JWKSUri:  server.URL + "/jwks",
	}

	err := provider.Initialize(config)
	require.NoError(t, err)

	t.Run("valid token", func(t *testing.T) {
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss":   server.URL,
			"aud":   "test-client",
			"sub":   "user123",
			"exp":   time.Now().Add(time.Hour).Unix(),
			"iat":   time.Now().Unix(),
			"email": "user@example.com",
			"name":  "Test User",
		})

		claims, err := provider.ValidateToken(context.Background(), token)
		require.NoError(t, err)
		require.NotNil(t, claims)
		assert.Equal(t, "user123", claims.Subject)
		assert.Equal(t, server.URL, claims.Issuer)

		email, exists := claims.GetClaimString("email")
		assert.True(t, exists)
		assert.Equal(t, "user@example.com", email)
	})

	t.Run("valid token with array audience", func(t *testing.T) {
		// RFC 7519 permits "aud" to be an array of strings.
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss":   server.URL,
			"aud":   []string{"test-client", "another-client"},
			"sub":   "user456",
			"exp":   time.Now().Add(time.Hour).Unix(),
			"iat":   time.Now().Unix(),
			"email": "user2@example.com",
			"name":  "Test User 2",
		})

		claims, err := provider.ValidateToken(context.Background(), token)
		require.NoError(t, err)
		require.NotNil(t, claims)
		assert.Equal(t, "user456", claims.Subject)
		assert.Equal(t, server.URL, claims.Issuer)

		email, exists := claims.GetClaimString("email")
		assert.True(t, exists)
		assert.Equal(t, "user2@example.com", email)
	})

	t.Run("expired token", func(t *testing.T) {
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss": server.URL,
			"aud": "test-client",
			"sub": "user123",
			"exp": time.Now().Add(-time.Hour).Unix(), // already expired
			"iat": time.Now().Add(-time.Hour * 2).Unix(),
		})

		_, err := provider.ValidateToken(context.Background(), token)
		// Use require rather than assert. If no error came back, err.Error()
		// below would panic on a nil error instead of failing cleanly.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "expired")
	})

	t.Run("invalid signature", func(t *testing.T) {
		// Sign with a key that is not published in the JWKS.
		wrongKey, _ := generateTestKeys(t)
		token := createTestJWT(t, wrongKey, jwt.MapClaims{
			"iss": server.URL,
			"aud": "test-client",
			"sub": "user123",
			"exp": time.Now().Add(time.Hour).Unix(),
			"iat": time.Now().Unix(),
		})

		_, err := provider.ValidateToken(context.Background(), token)
		assert.Error(t, err)
	})
}
|
|
|
|
// TestOIDCProviderAuthentication covers the Authenticate flow. A token signed
// by the published key must map to an identity carrying the token's subject,
// email, name, and groups, tagged with the provider name. A malformed token
// string must be rejected.
func TestOIDCProviderAuthentication(t *testing.T) {
	privateKey, publicKey := generateTestKeys(t)

	server := setupOIDCTestServer(t, publicKey)
	defer server.Close()

	roleMapping := &providers.RoleMapping{
		Rules: []providers.MappingRule{
			{Claim: "email", Value: "*@example.com", Role: "arn:seaweed:iam::role/UserRole"},
			{Claim: "groups", Value: "admins", Role: "arn:seaweed:iam::role/AdminRole"},
		},
		DefaultRole: "arn:seaweed:iam::role/GuestRole",
	}

	provider := NewOIDCProvider("test-oidc")
	initErr := provider.Initialize(&OIDCConfig{
		Issuer:      server.URL,
		ClientID:    "test-client",
		JWKSUri:     server.URL + "/jwks",
		RoleMapping: roleMapping,
	})
	require.NoError(t, initErr)

	ctx := context.Background()

	t.Run("successful authentication", func(t *testing.T) {
		token := createTestJWT(t, privateKey, jwt.MapClaims{
			"iss":    server.URL,
			"aud":    "test-client",
			"sub":    "user123",
			"exp":    time.Now().Add(time.Hour).Unix(),
			"iat":    time.Now().Unix(),
			"email":  "user@example.com",
			"name":   "Test User",
			"groups": []string{"users", "developers"},
		})

		identity, authErr := provider.Authenticate(ctx, token)
		require.NoError(t, authErr)
		require.NotNil(t, identity)

		assert.Equal(t, "user123", identity.UserID)
		assert.Equal(t, "user@example.com", identity.Email)
		assert.Equal(t, "Test User", identity.DisplayName)
		assert.Equal(t, "test-oidc", identity.Provider)
		for _, group := range []string{"users", "developers"} {
			assert.Contains(t, identity.Groups, group)
		}
	})

	t.Run("authentication with invalid token", func(t *testing.T) {
		_, authErr := provider.Authenticate(ctx, "invalid-token")
		assert.Error(t, authErr)
	})
}
|
|
|
|
// TestOIDCProviderUserInfo exercises the UserInfo lookups:
//   - a bearer access token is exchanged at the UserInfo endpoint for an
//     identity;
//   - tokens containing "admin" yield the "admins" group;
//   - empty and rejected tokens surface errors;
//   - a ClaimsMapping leaves the standard claims working;
//   - GetUserInfo rejects an empty user ID.
func TestOIDCProviderUserInfo(t *testing.T) {
	// The shared test server's /userinfo handler returns the canned responses
	// this test relies on:
	//   - 401 when the bearer token is missing or is "invalid-token";
	//   - groups ["admins"] for tokens containing "admin";
	//   - a fixed user in groups ["users", "developers"] otherwise.
	// Only a public key is needed to build the server.
	_, publicKey := generateTestKeys(t)

	server := setupOIDCTestServer(t, publicKey)
	defer server.Close()

	provider := NewOIDCProvider("test-oidc")
	config := &OIDCConfig{
		Issuer:      server.URL,
		ClientID:    "test-client",
		UserInfoUri: server.URL + "/userinfo",
	}

	err := provider.Initialize(config)
	require.NoError(t, err)

	t.Run("get user info with access token", func(t *testing.T) {
		// Round trip to the mock UserInfo endpoint.
		identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)
		assert.Equal(t, "user123", identity.UserID)
		assert.Equal(t, "user@example.com", identity.Email)
		assert.Equal(t, "Test User", identity.DisplayName)
		assert.Contains(t, identity.Groups, "users")
		assert.Contains(t, identity.Groups, "developers")
		assert.Equal(t, "test-oidc", identity.Provider)
	})

	t.Run("get admin user info", func(t *testing.T) {
		identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)
		assert.Equal(t, "user123", identity.UserID)
		assert.Contains(t, identity.Groups, "admins")
	})

	t.Run("get user info without token", func(t *testing.T) {
		_, err := provider.GetUserInfoWithToken(context.Background(), "")
		// Use require so a nil error fails the subtest instead of panicking
		// on err.Error() below.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "access token cannot be empty")
	})

	t.Run("get user info with invalid token", func(t *testing.T) {
		// The mock endpoint answers "invalid-token" with 401.
		_, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
	})

	t.Run("get user info with custom claims mapping", func(t *testing.T) {
		customProvider := NewOIDCProvider("test-custom-oidc")
		customConfig := &OIDCConfig{
			Issuer:      server.URL,
			ClientID:    "test-client",
			UserInfoUri: server.URL + "/userinfo",
			ClaimsMapping: map[string]string{
				"customEmail": "email",
				"customName":  "name",
			},
		}

		err := customProvider.Initialize(customConfig)
		require.NoError(t, err)

		identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
		require.NoError(t, err)
		require.NotNil(t, identity)

		// Standard claims must still be populated alongside the custom mapping.
		assert.Equal(t, "user123", identity.UserID)
		assert.Equal(t, "user@example.com", identity.Email)
		assert.Equal(t, "Test User", identity.DisplayName)
	})

	t.Run("get user info with empty id", func(t *testing.T) {
		_, err := provider.GetUserInfo(context.Background(), "")
		assert.Error(t, err)
	})
}
|
|
|
|
// Helper functions for testing
|
|
|
|
// generateTestKeys creates a fresh 2048-bit RSA key pair for signing test
// tokens. It returns the private key together with its public half and fails
// the test immediately if key generation returns an error.
func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
	t.Helper() // attribute failures to the caller's line, not this helper

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	return privateKey, &privateKey.PublicKey
}
|
|
|
|
// createTestJWT signs claims with privateKey using RS256 and returns the
// compact token string. The header carries kid "test-key-id", the key ID
// published by the test JWKS documents in this file; presumably the provider
// uses it to pick the verification key. The test fails if signing returns an
// error.
func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
	t.Helper() // attribute failures to the caller's line, not this helper

	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = "test-key-id"

	signed, err := token.SignedString(privateKey)
	require.NoError(t, err)

	return signed
}
|
|
|
|
// encodePublicKey returns the RSA modulus of publicKey as minimal big-endian
// bytes encoded in unpadded base64url. That is the form of the "n" member of
// an RSA JWK (RFC 7518 §6.3.1.1). The *testing.T parameter is unused; it keeps
// the signature uniform with the other test helpers.
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
	modulus := publicKey.N.Bytes()
	unpadded := base64.URLEncoding.WithPadding(base64.NoPadding)
	return unpadded.EncodeToString(modulus)
}
|
|
|
|
// setupOIDCTestServer starts an httptest server that mimics the OIDC
// endpoints these tests use:
//
//   - /.well-known/openid-configuration: a discovery document pointing at the
//     server's own /jwks and /userinfo endpoints.
//   - /jwks: a JWKS containing one RS256 signing key (kid "test-key-id")
//     built from publicKey.
//   - /userinfo: 401 when there is no bearer token or the token is
//     "invalid-token". Otherwise it returns a fixed user, with groups
//     ["admins"] when the token contains "admin" and ["users", "developers"]
//     otherwise.
//
// Any other path returns 404. The caller owns the server and must Close it.
func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
	t.Helper() // attribute failures to the caller's line, not this helper

	jwks := map[string]interface{}{
		"keys": []map[string]interface{}{
			{
				"kty": "RSA",
				"kid": "test-key-id",
				"use": "sig",
				"alg": "RS256",
				"n":   encodePublicKey(t, publicKey),
				"e":   "AQAB", // base64url of 65537, the exponent crypto/rsa.GenerateKey produces
			},
		},
	}

	// writeJSON sends v as an application/json body. Write errors are ignored
	// on purpose in this mock.
	writeJSON := func(w http.ResponseWriter, v interface{}) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(v)
	}

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		// OpenID Connect Discovery 1.0 §4 defines this path with a hyphen.
		case "/.well-known/openid-configuration":
			writeJSON(w, map[string]interface{}{
				"issuer":            "http://" + r.Host,
				"jwks_uri":          "http://" + r.Host + "/jwks",
				"userinfo_endpoint": "http://" + r.Host + "/userinfo",
			})

		case "/jwks":
			writeJSON(w, jwks)

		case "/userinfo":
			authHeader := r.Header.Get("Authorization")
			if !strings.HasPrefix(authHeader, "Bearer ") {
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte(`{"error": "unauthorized"}`))
				return
			}

			accessToken := strings.TrimPrefix(authHeader, "Bearer ")

			// Reject this one token explicitly so tests can exercise the
			// 401 path.
			if accessToken == "invalid-token" {
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte(`{"error": "invalid_token"}`))
				return
			}

			userInfo := map[string]interface{}{
				"sub":    "user123",
				"email":  "user@example.com",
				"name":   "Test User",
				"groups": []string{"users", "developers"},
			}
			if strings.Contains(accessToken, "admin") {
				userInfo["groups"] = []string{"admins"}
			}

			writeJSON(w, userInfo)

		default:
			http.NotFound(w, r)
		}
	}))
}
|