package iam import ( "context" "crypto/rand" "crypto/rsa" "encoding/base64" "fmt" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) const ( TestS3Endpoint = "http://localhost:8333" TestRegion = "us-west-2" ) // S3IAMTestFramework provides utilities for S3+IAM integration testing type S3IAMTestFramework struct { t *testing.T mockOIDC *httptest.Server privateKey *rsa.PrivateKey publicKey *rsa.PublicKey createdBuckets []string ctx context.Context } // NewS3IAMTestFramework creates a new test framework instance func NewS3IAMTestFramework(t *testing.T) *S3IAMTestFramework { framework := &S3IAMTestFramework{ t: t, ctx: context.Background(), createdBuckets: make([]string, 0), } // Generate RSA keys for JWT signing var err error framework.privateKey, err = rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) framework.publicKey = &framework.privateKey.PublicKey // Setup mock OIDC server framework.setupMockOIDCServer() return framework } // setupMockOIDCServer creates a mock OIDC server for testing func (f *S3IAMTestFramework) setupMockOIDCServer() { f.mockOIDC = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/openid_configuration": config := map[string]interface{}{ "issuer": "http://" + r.Host, "jwks_uri": "http://" + r.Host + "/jwks", "userinfo_endpoint": "http://" + r.Host + "/userinfo", } w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s", "userinfo_endpoint": "%s" }`, config["issuer"], config["jwks_uri"], config["userinfo_endpoint"]) case "/jwks": w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{ "keys": [ { "kty": "RSA", "kid": "test-key-id", "use": "sig", "alg": "RS256", "n": "%s", "e": "AQAB" } ] }`, f.encodePublicKey()) case "/userinfo": authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { w.WriteHeader(http.StatusUnauthorized) return } token := strings.TrimPrefix(authHeader, "Bearer ") userInfo := map[string]interface{}{ "sub": "test-user", "email": "test@example.com", "name": "Test User", "groups": []string{"users", "developers"}, } if strings.Contains(token, "admin") { userInfo["groups"] = []string{"admins"} } w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{ "sub": "%s", "email": "%s", "name": "%s", "groups": %v }`, userInfo["sub"], userInfo["email"], userInfo["name"], userInfo["groups"]) default: http.NotFound(w, r) } })) } // encodePublicKey encodes the RSA public key for JWKS func (f *S3IAMTestFramework) encodePublicKey() string { return base64.RawURLEncoding.EncodeToString(f.publicKey.N.Bytes()) } // BearerTokenTransport is an HTTP transport that adds Bearer token authentication type BearerTokenTransport struct { Transport http.RoundTripper Token string } // RoundTrip implements the http.RoundTripper interface func (t *BearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Clone the request to avoid modifying the original newReq := req.Clone(req.Context()) // Add Bearer token authorization header newReq.Header.Set("Authorization", "Bearer "+t.Token) // Remove AWS signature headers if present (they conflict with Bearer auth) newReq.Header.Del("X-Amz-Date") newReq.Header.Del("X-Amz-Content-Sha256") newReq.Header.Del("X-Amz-Signature") newReq.Header.Del("X-Amz-Algorithm") newReq.Header.Del("X-Amz-Credential") newReq.Header.Del("X-Amz-SignedHeaders") // Use underlying transport transport := t.Transport if transport == nil { transport = http.DefaultTransport } return transport.RoundTrip(newReq) } // generateSTSSessionToken creates a session token using the actual STS service for proper validation func (f *S3IAMTestFramework) generateSTSSessionToken(username, roleName string, validDuration time.Duration) (string, error) { // For now, simulate what the STS service would return by calling AssumeRoleWithWebIdentity // In a real test, we'd make an actual HTTP call to the STS endpoint // But for unit testing, we'll create a realistic JWT manually that will pass validation now := time.Now() signingKeyB64 := "dGVzdC1zaWduaW5nLWtleS0zMi1jaGFyYWN0ZXJzLWxvbmc=" signingKey, err := base64.StdEncoding.DecodeString(signingKeyB64) if err != nil { return "", fmt.Errorf("failed to decode signing key: %v", err) } // Generate a session ID that would be created by the STS service sessionId := fmt.Sprintf("test-session-%s-%s-%d", username, roleName, now.Unix()) // Create session token claims exactly as TokenGenerator does sessionClaims := jwt.MapClaims{ "iss": "seaweedfs-sts", "sub": sessionId, "iat": now.Unix(), "exp": now.Add(validDuration).Unix(), "token_type": "session", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, sessionClaims) tokenString, err := token.SignedString(signingKey) if err != nil { return "", err } // Note: In a real implementation, we would also create the session in the STS session store // For now, we'll rely on the fact that if JWT validation passes, the session should be considered valid // This is a limitation of our current testing approach return tokenString, nil } // CreateS3ClientWithJWT creates an S3 client authenticated with a JWT token for the specified role func (f *S3IAMTestFramework) CreateS3ClientWithJWT(username, roleName string) (*s3.S3, error) { // Generate STS session token (not OIDC token) token, err := f.generateSTSSessionToken(username, roleName, time.Hour) if err != nil { return nil, fmt.Errorf("failed to generate STS session token: %v", err) } // Create custom HTTP client with Bearer token transport httpClient := &http.Client{ Transport: &BearerTokenTransport{ Token: token, }, } sess, err := session.NewSession(&aws.Config{ Region: aws.String(TestRegion), Endpoint: aws.String(TestS3Endpoint), HTTPClient: httpClient, // Use anonymous credentials to avoid AWS signature generation Credentials: credentials.AnonymousCredentials, DisableSSL: aws.Bool(true), S3ForcePathStyle: aws.Bool(true), }) if err != nil { return nil, fmt.Errorf("failed to create AWS session: %v", err) } return s3.New(sess), nil } // CreateS3ClientWithInvalidJWT creates an S3 client with an invalid JWT token func (f *S3IAMTestFramework) CreateS3ClientWithInvalidJWT() (*s3.S3, error) { invalidToken := "invalid.jwt.token" // Create custom HTTP client with Bearer token transport httpClient := &http.Client{ Transport: &BearerTokenTransport{ Token: invalidToken, }, } sess, err := session.NewSession(&aws.Config{ Region: aws.String(TestRegion), Endpoint: aws.String(TestS3Endpoint), HTTPClient: httpClient, // Use anonymous credentials to avoid AWS signature generation Credentials: credentials.AnonymousCredentials, DisableSSL: aws.Bool(true), S3ForcePathStyle: aws.Bool(true), }) if err != nil { return nil, fmt.Errorf("failed to create AWS session: %v", err) } return s3.New(sess), nil } // CreateS3ClientWithExpiredJWT creates an S3 client with an expired JWT token func (f *S3IAMTestFramework) CreateS3ClientWithExpiredJWT(username, roleName string) (*s3.S3, error) { // Generate expired STS session token (expired 1 hour ago) token, err := f.generateSTSSessionToken(username, roleName, -time.Hour) if err != nil { return nil, fmt.Errorf("failed to generate expired STS session token: %v", err) } // Create custom HTTP client with Bearer token transport httpClient := &http.Client{ Transport: &BearerTokenTransport{ Token: token, }, } sess, err := session.NewSession(&aws.Config{ Region: aws.String(TestRegion), Endpoint: aws.String(TestS3Endpoint), HTTPClient: httpClient, // Use anonymous credentials to avoid AWS signature generation Credentials: credentials.AnonymousCredentials, DisableSSL: aws.Bool(true), S3ForcePathStyle: aws.Bool(true), }) if err != nil { return nil, fmt.Errorf("failed to create AWS session: %v", err) } return s3.New(sess), nil } // CreateS3ClientWithSessionToken creates an S3 client with a session token func (f *S3IAMTestFramework) CreateS3ClientWithSessionToken(sessionToken string) (*s3.S3, error) { sess, err := session.NewSession(&aws.Config{ Region: aws.String(TestRegion), Endpoint: aws.String(TestS3Endpoint), Credentials: credentials.NewStaticCredentials( "session-access-key", "session-secret-key", sessionToken, ), DisableSSL: aws.Bool(true), S3ForcePathStyle: aws.Bool(true), }) if err != nil { return nil, fmt.Errorf("failed to create AWS session: %v", err) } return s3.New(sess), nil } // generateJWTToken creates a JWT token for testing func (f *S3IAMTestFramework) generateJWTToken(username, roleName string, validDuration time.Duration) (string, error) { now := time.Now() claims := jwt.MapClaims{ "sub": username, "iss": f.mockOIDC.URL, "aud": "test-client", "exp": now.Add(validDuration).Unix(), "iat": now.Unix(), "email": username + "@example.com", "name": strings.Title(username), } // Add role-specific groups switch roleName { case "TestAdminRole": claims["groups"] = []string{"admins"} case "TestReadOnlyRole": claims["groups"] = []string{"users"} case "TestWriteOnlyRole": claims["groups"] = []string{"writers"} default: claims["groups"] = []string{"users"} } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = "test-key-id" tokenString, err := token.SignedString(f.privateKey) if err != nil { return "", fmt.Errorf("failed to sign token: %v", err) } return tokenString, nil } // CreateShortLivedSessionToken creates a mock session token for testing func (f *S3IAMTestFramework) CreateShortLivedSessionToken(username, roleName string, durationSeconds int64) (string, error) { // For testing purposes, create a mock session token // In reality, this would be generated by the STS service return fmt.Sprintf("mock-session-token-%s-%s-%d", username, roleName, time.Now().Unix()), nil } // ExpireSessionForTesting simulates session expiration for testing func (f *S3IAMTestFramework) ExpireSessionForTesting(sessionToken string) error { // For integration tests, this would typically involve calling the STS service // For now, we just simulate success since the actual expiration will be handled by SeaweedFS return nil } // CreateBucket creates a bucket and tracks it for cleanup func (f *S3IAMTestFramework) CreateBucket(s3Client *s3.S3, bucketName string) error { _, err := s3Client.CreateBucket(&s3.CreateBucketInput{ Bucket: aws.String(bucketName), }) if err != nil { return err } // Track bucket for cleanup f.createdBuckets = append(f.createdBuckets, bucketName) return nil } // Cleanup cleans up test resources func (f *S3IAMTestFramework) Cleanup() { // Clean up buckets (best effort) if len(f.createdBuckets) > 0 { // Create admin client for cleanup adminClient, err := f.CreateS3ClientWithJWT("admin-user", "TestAdminRole") if err == nil { for _, bucket := range f.createdBuckets { // Try to empty bucket first listResult, err := adminClient.ListObjects(&s3.ListObjectsInput{ Bucket: aws.String(bucket), }) if err == nil { for _, obj := range listResult.Contents { adminClient.DeleteObject(&s3.DeleteObjectInput{ Bucket: aws.String(bucket), Key: obj.Key, }) } } // Delete bucket adminClient.DeleteBucket(&s3.DeleteBucketInput{ Bucket: aws.String(bucket), }) } } } // Close mock OIDC server if f.mockOIDC != nil { f.mockOIDC.Close() } } // WaitForS3Service waits for the S3 service to be available func (f *S3IAMTestFramework) WaitForS3Service() error { // Create a basic S3 client sess, err := session.NewSession(&aws.Config{ Region: aws.String(TestRegion), Endpoint: aws.String(TestS3Endpoint), Credentials: credentials.NewStaticCredentials( "test-access-key", "test-secret-key", "", ), DisableSSL: aws.Bool(true), S3ForcePathStyle: aws.Bool(true), }) if err != nil { return fmt.Errorf("failed to create AWS session: %v", err) } s3Client := s3.New(sess) // Try to list buckets to check if service is available maxRetries := 30 for i := 0; i < maxRetries; i++ { _, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) if err == nil { return nil } time.Sleep(1 * time.Second) } return fmt.Errorf("S3 service not available after %d retries", maxRetries) } // WaitForS3Service waits for the S3 service to be available (simplified version) func (f *S3IAMTestFramework) WaitForS3ServiceSimple() error { // This is a simplified version that just checks if the endpoint responds // The full implementation would be in the Makefile's wait-for-services target return nil }