You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							400 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							400 lines
						
					
					
						
							12 KiB
						
					
					
				| package s3api | |
| 
 | |
| import ( | |
| 	"bytes" | |
| 	"fmt" | |
| 	"io" | |
| 	"net/http" | |
| 	"strings" | |
| 	"testing" | |
| 
 | |
| 	"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" | |
| ) | |
| 
 | |
| // TestSSECWrongKeyDecryption tests decryption with wrong SSE-C key | |
| func TestSSECWrongKeyDecryption(t *testing.T) { | |
| 	// Setup original key and encrypt data | |
| 	originalKey := GenerateTestSSECKey(1) | |
| 	testData := "Hello, SSE-C world!" | |
| 
 | |
| 	encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), &SSECustomerKey{ | |
| 		Algorithm: "AES256", | |
| 		Key:       originalKey.Key, | |
| 		KeyMD5:    originalKey.KeyMD5, | |
| 	}) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create encrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	// Read encrypted data | |
| 	encryptedData, err := io.ReadAll(encryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read encrypted data: %v", err) | |
| 	} | |
| 
 | |
| 	// Try to decrypt with wrong key | |
| 	wrongKey := GenerateTestSSECKey(2) // Different seed = different key | |
| 	decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), &SSECustomerKey{ | |
| 		Algorithm: "AES256", | |
| 		Key:       wrongKey.Key, | |
| 		KeyMD5:    wrongKey.KeyMD5, | |
| 	}, iv) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create decrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	// Read decrypted data - should be garbage/different from original | |
| 	decryptedData, err := io.ReadAll(decryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read decrypted data: %v", err) | |
| 	} | |
| 
 | |
| 	// Verify the decrypted data is NOT the same as original (wrong key used) | |
| 	if string(decryptedData) == testData { | |
| 		t.Error("Decryption with wrong key should not produce original data") | |
| 	} | |
| } | |
| 
 | |
| // TestSSEKMSKeyNotFound tests handling of missing KMS key | |
| func TestSSEKMSKeyNotFound(t *testing.T) { | |
| 	// Note: The local KMS provider creates keys on-demand by design. | |
| 	// This test validates that when on-demand creation fails or is disabled, | |
| 	// appropriate errors are returned. | |
|  | |
| 	// Test with an invalid key ID that would fail even on-demand creation | |
| 	invalidKeyID := "" // Empty key ID should fail | |
| 	encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) | |
| 
 | |
| 	_, _, err := CreateSSEKMSEncryptedReader(strings.NewReader("test data"), invalidKeyID, encryptionContext) | |
| 
 | |
| 	// Should get an error for invalid/empty key | |
| 	if err == nil { | |
| 		t.Error("Expected error for empty KMS key ID, got none") | |
| 	} | |
| 
 | |
| 	// For local KMS with on-demand creation, we test what we can realistically test | |
| 	if err != nil { | |
| 		t.Logf("Got expected error for empty key ID: %v", err) | |
| 	} | |
| } | |
| 
 | |
| // TestSSEHeadersWithoutEncryption tests inconsistent state where headers are present but no encryption | |
| func TestSSEHeadersWithoutEncryption(t *testing.T) { | |
| 	testCases := []struct { | |
| 		name     string | |
| 		setupReq func() *http.Request | |
| 	}{ | |
| 		{ | |
| 			name: "SSE-C algorithm without key", | |
| 			setupReq: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") | |
| 				// Missing key and MD5 | |
| 				return req | |
| 			}, | |
| 		}, | |
| 		{ | |
| 			name: "SSE-C key without algorithm", | |
| 			setupReq: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				keyPair := GenerateTestSSECKey(1) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyPair.KeyB64) | |
| 				// Missing algorithm | |
| 				return req | |
| 			}, | |
| 		}, | |
| 		{ | |
| 			name: "SSE-KMS key ID without algorithm", | |
| 			setupReq: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "test-key-id") | |
| 				// Missing algorithm | |
| 				return req | |
| 			}, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tc := range testCases { | |
| 		t.Run(tc.name, func(t *testing.T) { | |
| 			req := tc.setupReq() | |
| 
 | |
| 			// Validate headers - should catch incomplete configurations | |
| 			if strings.Contains(tc.name, "SSE-C") { | |
| 				err := ValidateSSECHeaders(req) | |
| 				if err == nil { | |
| 					t.Error("Expected validation error for incomplete SSE-C headers") | |
| 				} | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestSSECInvalidKeyFormats tests various invalid SSE-C key formats | |
| func TestSSECInvalidKeyFormats(t *testing.T) { | |
| 	testCases := []struct { | |
| 		name      string | |
| 		algorithm string | |
| 		key       string | |
| 		keyMD5    string | |
| 		expectErr bool | |
| 	}{ | |
| 		{ | |
| 			name:      "Invalid algorithm", | |
| 			algorithm: "AES128", | |
| 			key:       "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", // 32 bytes base64 | |
| 			keyMD5:    "valid-md5-hash", | |
| 			expectErr: true, | |
| 		}, | |
| 		{ | |
| 			name:      "Invalid key length (too short)", | |
| 			algorithm: "AES256", | |
| 			key:       "c2hvcnRrZXk=", // "shortkey" base64 - too short | |
| 			keyMD5:    "valid-md5-hash", | |
| 			expectErr: true, | |
| 		}, | |
| 		{ | |
| 			name:      "Invalid key length (too long)", | |
| 			algorithm: "AES256", | |
| 			key:       "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleQ==", // too long | |
| 			keyMD5:    "valid-md5-hash", | |
| 			expectErr: true, | |
| 		}, | |
| 		{ | |
| 			name:      "Invalid base64 key", | |
| 			algorithm: "AES256", | |
| 			key:       "invalid-base64!", | |
| 			keyMD5:    "valid-md5-hash", | |
| 			expectErr: true, | |
| 		}, | |
| 		{ | |
| 			name:      "Invalid base64 MD5", | |
| 			algorithm: "AES256", | |
| 			key:       "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", | |
| 			keyMD5:    "invalid-base64!", | |
| 			expectErr: true, | |
| 		}, | |
| 		{ | |
| 			name:      "Mismatched MD5", | |
| 			algorithm: "AES256", | |
| 			key:       "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", | |
| 			keyMD5:    "d29uZy1tZDUtaGFzaA==", // "wrong-md5-hash" base64 | |
| 			expectErr: true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tc := range testCases { | |
| 		t.Run(tc.name, func(t *testing.T) { | |
| 			req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 			req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tc.algorithm) | |
| 			req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tc.key) | |
| 			req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tc.keyMD5) | |
| 
 | |
| 			err := ValidateSSECHeaders(req) | |
| 			if tc.expectErr && err == nil { | |
| 				t.Errorf("Expected error for %s, but got none", tc.name) | |
| 			} | |
| 			if !tc.expectErr && err != nil { | |
| 				t.Errorf("Expected no error for %s, but got: %v", tc.name, err) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestSSEKMSInvalidConfigurations tests various invalid SSE-KMS configurations | |
| func TestSSEKMSInvalidConfigurations(t *testing.T) { | |
| 	testCases := []struct { | |
| 		name         string | |
| 		setupRequest func() *http.Request | |
| 		expectError  bool | |
| 	}{ | |
| 		{ | |
| 			name: "Invalid algorithm", | |
| 			setupRequest: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryption, "invalid-algorithm") | |
| 				return req | |
| 			}, | |
| 			expectError: true, | |
| 		}, | |
| 		{ | |
| 			name: "Empty key ID", | |
| 			setupRequest: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms") | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "") | |
| 				return req | |
| 			}, | |
| 			expectError: false, // Empty key ID might be valid (use default) | |
| 		}, | |
| 		{ | |
| 			name: "Invalid key ID format", | |
| 			setupRequest: func() *http.Request { | |
| 				req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms") | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "invalid key id with spaces") | |
| 				return req | |
| 			}, | |
| 			expectError: true, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tc := range testCases { | |
| 		t.Run(tc.name, func(t *testing.T) { | |
| 			req := tc.setupRequest() | |
| 
 | |
| 			_, err := ParseSSEKMSHeaders(req) | |
| 			if tc.expectError && err == nil { | |
| 				t.Errorf("Expected error for %s, but got none", tc.name) | |
| 			} | |
| 			if !tc.expectError && err != nil { | |
| 				t.Errorf("Expected no error for %s, but got: %v", tc.name, err) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| // TestSSEEmptyDataHandling tests handling of empty data with SSE | |
| func TestSSEEmptyDataHandling(t *testing.T) { | |
| 	t.Run("SSE-C with empty data", func(t *testing.T) { | |
| 		keyPair := GenerateTestSSECKey(1) | |
| 		customerKey := &SSECustomerKey{ | |
| 			Algorithm: "AES256", | |
| 			Key:       keyPair.Key, | |
| 			KeyMD5:    keyPair.KeyMD5, | |
| 		} | |
| 
 | |
| 		// Encrypt empty data | |
| 		encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(""), customerKey) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to create encrypted reader for empty data: %v", err) | |
| 		} | |
| 
 | |
| 		encryptedData, err := io.ReadAll(encryptedReader) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to read encrypted empty data: %v", err) | |
| 		} | |
| 
 | |
| 		// Should have IV for empty data | |
| 		if len(iv) != s3_constants.AESBlockSize { | |
| 			t.Error("IV should be present even for empty data") | |
| 		} | |
| 
 | |
| 		// Decrypt and verify | |
| 		decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to create decrypted reader for empty data: %v", err) | |
| 		} | |
| 
 | |
| 		decryptedData, err := io.ReadAll(decryptedReader) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to read decrypted empty data: %v", err) | |
| 		} | |
| 
 | |
| 		if len(decryptedData) != 0 { | |
| 			t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData)) | |
| 		} | |
| 	}) | |
| 
 | |
| 	t.Run("SSE-KMS with empty data", func(t *testing.T) { | |
| 		kmsKey := SetupTestKMS(t) | |
| 		defer kmsKey.Cleanup() | |
| 
 | |
| 		encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) | |
| 
 | |
| 		// Encrypt empty data | |
| 		encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(""), kmsKey.KeyID, encryptionContext) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to create encrypted reader for empty data: %v", err) | |
| 		} | |
| 
 | |
| 		encryptedData, err := io.ReadAll(encryptedReader) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to read encrypted empty data: %v", err) | |
| 		} | |
| 
 | |
| 		// Empty data should produce empty encrypted data (IV is stored in metadata) | |
| 		if len(encryptedData) != 0 { | |
| 			t.Errorf("Encrypted empty data should be empty, got %d bytes", len(encryptedData)) | |
| 		} | |
| 
 | |
| 		// Decrypt and verify | |
| 		decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to create decrypted reader for empty data: %v", err) | |
| 		} | |
| 
 | |
| 		decryptedData, err := io.ReadAll(decryptedReader) | |
| 		if err != nil { | |
| 			t.Fatalf("Failed to read decrypted empty data: %v", err) | |
| 		} | |
| 
 | |
| 		if len(decryptedData) != 0 { | |
| 			t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData)) | |
| 		} | |
| 	}) | |
| } | |
| 
 | |
| // TestSSEConcurrentAccess tests SSE operations under concurrent access | |
| func TestSSEConcurrentAccess(t *testing.T) { | |
| 	keyPair := GenerateTestSSECKey(1) | |
| 	customerKey := &SSECustomerKey{ | |
| 		Algorithm: "AES256", | |
| 		Key:       keyPair.Key, | |
| 		KeyMD5:    keyPair.KeyMD5, | |
| 	} | |
| 
 | |
| 	const numGoroutines = 10 | |
| 	done := make(chan bool, numGoroutines) | |
| 	errors := make(chan error, numGoroutines) | |
| 
 | |
| 	// Run multiple encryption/decryption operations concurrently | |
| 	for i := 0; i < numGoroutines; i++ { | |
| 		go func(id int) { | |
| 			defer func() { done <- true }() | |
| 
 | |
| 			testData := fmt.Sprintf("test data %d", id) | |
| 
 | |
| 			// Encrypt | |
| 			encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), customerKey) | |
| 			if err != nil { | |
| 				errors <- fmt.Errorf("goroutine %d encrypt error: %v", id, err) | |
| 				return | |
| 			} | |
| 
 | |
| 			encryptedData, err := io.ReadAll(encryptedReader) | |
| 			if err != nil { | |
| 				errors <- fmt.Errorf("goroutine %d read encrypted error: %v", id, err) | |
| 				return | |
| 			} | |
| 
 | |
| 			// Decrypt | |
| 			decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) | |
| 			if err != nil { | |
| 				errors <- fmt.Errorf("goroutine %d decrypt error: %v", id, err) | |
| 				return | |
| 			} | |
| 
 | |
| 			decryptedData, err := io.ReadAll(decryptedReader) | |
| 			if err != nil { | |
| 				errors <- fmt.Errorf("goroutine %d read decrypted error: %v", id, err) | |
| 				return | |
| 			} | |
| 
 | |
| 			if string(decryptedData) != testData { | |
| 				errors <- fmt.Errorf("goroutine %d data mismatch: expected %s, got %s", id, testData, string(decryptedData)) | |
| 				return | |
| 			} | |
| 		}(i) | |
| 	} | |
| 
 | |
| 	// Wait for all goroutines to complete | |
| 	for i := 0; i < numGoroutines; i++ { | |
| 		<-done | |
| 	} | |
| 
 | |
| 	// Check for errors | |
| 	close(errors) | |
| 	for err := range errors { | |
| 		t.Error(err) | |
| 	} | |
| }
 |