You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							407 lines
						
					
					
						
							11 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							407 lines
						
					
					
						
							11 KiB
						
					
					
				| package s3api | |
| 
 | |
| import ( | |
| 	"bytes" | |
| 	"crypto/md5" | |
| 	"encoding/base64" | |
| 	"fmt" | |
| 	"io" | |
| 	"net/http" | |
| 	"testing" | |
| 
 | |
| 	"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" | |
| ) | |
| 
 | |
// base64MD5 hashes b with MD5 and returns the digest as padded standard
// base64. The tests use it to build SSE-C key-MD5 header values.
func base64MD5(b []byte) string {
	hasher := md5.New()
	hasher.Write(b) // hash.Hash.Write never returns an error
	return base64.StdEncoding.EncodeToString(hasher.Sum(nil))
}
| 
 | |
| func TestSSECHeaderValidation(t *testing.T) { | |
| 	// Test valid SSE-C headers | |
| 	req := &http.Request{Header: make(http.Header)} | |
| 
 | |
| 	key := make([]byte, 32) // 256-bit key | |
| 	for i := range key { | |
| 		key[i] = byte(i) | |
| 	} | |
| 
 | |
| 	keyBase64 := base64.StdEncoding.EncodeToString(key) | |
| 	md5sum := md5.Sum(key) | |
| 	keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:]) | |
| 
 | |
| 	req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") | |
| 	req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64) | |
| 	req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5) | |
| 
 | |
| 	// Test validation | |
| 	err := ValidateSSECHeaders(req) | |
| 	if err != nil { | |
| 		t.Errorf("Expected valid headers, got error: %v", err) | |
| 	} | |
| 
 | |
| 	// Test parsing | |
| 	customerKey, err := ParseSSECHeaders(req) | |
| 	if err != nil { | |
| 		t.Errorf("Expected successful parsing, got error: %v", err) | |
| 	} | |
| 
 | |
| 	if customerKey == nil { | |
| 		t.Error("Expected customer key, got nil") | |
| 	} | |
| 
 | |
| 	if customerKey.Algorithm != "AES256" { | |
| 		t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm) | |
| 	} | |
| 
 | |
| 	if !bytes.Equal(customerKey.Key, key) { | |
| 		t.Error("Key doesn't match original") | |
| 	} | |
| 
 | |
| 	if customerKey.KeyMD5 != keyMD5 { | |
| 		t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5) | |
| 	} | |
| } | |
| 
 | |
| func TestSSECCopySourceHeaders(t *testing.T) { | |
| 	// Test valid SSE-C copy source headers | |
| 	req := &http.Request{Header: make(http.Header)} | |
| 
 | |
| 	key := make([]byte, 32) // 256-bit key | |
| 	for i := range key { | |
| 		key[i] = byte(i) + 1 // Different from regular test | |
| 	} | |
| 
 | |
| 	keyBase64 := base64.StdEncoding.EncodeToString(key) | |
| 	md5sum2 := md5.Sum(key) | |
| 	keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:]) | |
| 
 | |
| 	req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256") | |
| 	req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64) | |
| 	req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5) | |
| 
 | |
| 	// Test parsing copy source headers | |
| 	customerKey, err := ParseSSECCopySourceHeaders(req) | |
| 	if err != nil { | |
| 		t.Errorf("Expected successful copy source parsing, got error: %v", err) | |
| 	} | |
| 
 | |
| 	if customerKey == nil { | |
| 		t.Error("Expected customer key from copy source headers, got nil") | |
| 	} | |
| 
 | |
| 	if customerKey.Algorithm != "AES256" { | |
| 		t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm) | |
| 	} | |
| 
 | |
| 	if !bytes.Equal(customerKey.Key, key) { | |
| 		t.Error("Copy source key doesn't match original") | |
| 	} | |
| 
 | |
| 	// Test that regular headers don't interfere with copy source headers | |
| 	regularKey, err := ParseSSECHeaders(req) | |
| 	if err != nil { | |
| 		t.Errorf("Regular header parsing should not fail: %v", err) | |
| 	} | |
| 
 | |
| 	if regularKey != nil { | |
| 		t.Error("Expected nil for regular headers when only copy source headers are present") | |
| 	} | |
| } | |
| 
 | |
| func TestSSECHeaderValidationErrors(t *testing.T) { | |
| 	tests := []struct { | |
| 		name      string | |
| 		algorithm string | |
| 		key       string | |
| 		keyMD5    string | |
| 		wantErr   error | |
| 	}{ | |
| 		{ | |
| 			name:      "invalid algorithm", | |
| 			algorithm: "AES128", | |
| 			key:       base64.StdEncoding.EncodeToString(make([]byte, 32)), | |
| 			keyMD5:    base64MD5(make([]byte, 32)), | |
| 			wantErr:   ErrInvalidEncryptionAlgorithm, | |
| 		}, | |
| 		{ | |
| 			name:      "invalid key length", | |
| 			algorithm: "AES256", | |
| 			key:       base64.StdEncoding.EncodeToString(make([]byte, 16)), | |
| 			keyMD5:    base64MD5(make([]byte, 16)), | |
| 			wantErr:   ErrInvalidEncryptionKey, | |
| 		}, | |
| 		{ | |
| 			name:      "mismatched MD5", | |
| 			algorithm: "AES256", | |
| 			key:       base64.StdEncoding.EncodeToString(make([]byte, 32)), | |
| 			keyMD5:    "wrong==md5", | |
| 			wantErr:   ErrSSECustomerKeyMD5Mismatch, | |
| 		}, | |
| 		{ | |
| 			name:      "incomplete headers", | |
| 			algorithm: "AES256", | |
| 			key:       "", | |
| 			keyMD5:    "", | |
| 			wantErr:   ErrInvalidRequest, | |
| 		}, | |
| 	} | |
| 
 | |
| 	for _, tt := range tests { | |
| 		t.Run(tt.name, func(t *testing.T) { | |
| 			req := &http.Request{Header: make(http.Header)} | |
| 
 | |
| 			if tt.algorithm != "" { | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm) | |
| 			} | |
| 			if tt.key != "" { | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key) | |
| 			} | |
| 			if tt.keyMD5 != "" { | |
| 				req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5) | |
| 			} | |
| 
 | |
| 			err := ValidateSSECHeaders(req) | |
| 			if err != tt.wantErr { | |
| 				t.Errorf("Expected error %v, got %v", tt.wantErr, err) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| func TestSSECEncryptionDecryption(t *testing.T) { | |
| 	// Create customer key | |
| 	key := make([]byte, 32) | |
| 	for i := range key { | |
| 		key[i] = byte(i) | |
| 	} | |
| 
 | |
| 	md5sumKey := md5.Sum(key) | |
| 	customerKey := &SSECustomerKey{ | |
| 		Algorithm: "AES256", | |
| 		Key:       key, | |
| 		KeyMD5:    base64.StdEncoding.EncodeToString(md5sumKey[:]), | |
| 	} | |
| 
 | |
| 	// Test data | |
| 	testData := []byte("Hello, World! This is a test of SSE-C encryption.") | |
| 
 | |
| 	// Create encrypted reader | |
| 	dataReader := bytes.NewReader(testData) | |
| 	encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create encrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	// Read encrypted data | |
| 	encryptedData, err := io.ReadAll(encryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read encrypted data: %v", err) | |
| 	} | |
| 
 | |
| 	// Verify data is actually encrypted (different from original) | |
| 	if bytes.Equal(encryptedData[16:], testData) { // Skip IV | |
| 		t.Error("Data doesn't appear to be encrypted") | |
| 	} | |
| 
 | |
| 	// Create decrypted reader | |
| 	encryptedReader2 := bytes.NewReader(encryptedData) | |
| 	decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create decrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	// Read decrypted data | |
| 	decryptedData, err := io.ReadAll(decryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read decrypted data: %v", err) | |
| 	} | |
| 
 | |
| 	// Verify decrypted data matches original | |
| 	if !bytes.Equal(decryptedData, testData) { | |
| 		t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData) | |
| 	} | |
| } | |
| 
 | |
| func TestSSECIsSSECRequest(t *testing.T) { | |
| 	// Test with SSE-C headers | |
| 	req := &http.Request{Header: make(http.Header)} | |
| 	req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") | |
| 
 | |
| 	if !IsSSECRequest(req) { | |
| 		t.Error("Expected IsSSECRequest to return true when SSE-C headers are present") | |
| 	} | |
| 
 | |
| 	// Test without SSE-C headers | |
| 	req2 := &http.Request{Header: make(http.Header)} | |
| 	if IsSSECRequest(req2) { | |
| 		t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present") | |
| 	} | |
| } | |
| 
 | |
| // Test encryption with different data sizes (similar to s3tests) | |
| func TestSSECEncryptionVariousSizes(t *testing.T) { | |
| 	sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB | |
|  | |
| 	for _, size := range sizes { | |
| 		t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { | |
| 			// Create customer key | |
| 			key := make([]byte, 32) | |
| 			for i := range key { | |
| 				key[i] = byte(i + size) // Make key unique per test | |
| 			} | |
| 
 | |
| 			md5sumDyn := md5.Sum(key) | |
| 			customerKey := &SSECustomerKey{ | |
| 				Algorithm: "AES256", | |
| 				Key:       key, | |
| 				KeyMD5:    base64.StdEncoding.EncodeToString(md5sumDyn[:]), | |
| 			} | |
| 
 | |
| 			// Create test data of specified size | |
| 			testData := make([]byte, size) | |
| 			for i := range testData { | |
| 				testData[i] = byte('A' + (i % 26)) // Pattern of A-Z | |
| 			} | |
| 
 | |
| 			// Encrypt | |
| 			dataReader := bytes.NewReader(testData) | |
| 			encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to create encrypted reader: %v", err) | |
| 			} | |
| 
 | |
| 			encryptedData, err := io.ReadAll(encryptedReader) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to read encrypted data: %v", err) | |
| 			} | |
| 
 | |
| 			// Verify encrypted data has same size as original (IV is stored in metadata, not in stream) | |
| 			if len(encryptedData) != size { | |
| 				t.Errorf("Expected encrypted data length %d (same as original), got %d", size, len(encryptedData)) | |
| 			} | |
| 
 | |
| 			// Decrypt | |
| 			encryptedReader2 := bytes.NewReader(encryptedData) | |
| 			decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to create decrypted reader: %v", err) | |
| 			} | |
| 
 | |
| 			decryptedData, err := io.ReadAll(decryptedReader) | |
| 			if err != nil { | |
| 				t.Fatalf("Failed to read decrypted data: %v", err) | |
| 			} | |
| 
 | |
| 			// Verify decrypted data matches original | |
| 			if !bytes.Equal(decryptedData, testData) { | |
| 				t.Errorf("Decrypted data doesn't match original for size %d", size) | |
| 			} | |
| 		}) | |
| 	} | |
| } | |
| 
 | |
| func TestSSECEncryptionWithNilKey(t *testing.T) { | |
| 	testData := []byte("test data") | |
| 	dataReader := bytes.NewReader(testData) | |
| 
 | |
| 	// Test encryption with nil key (should pass through) | |
| 	encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, nil) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create encrypted reader with nil key: %v", err) | |
| 	} | |
| 
 | |
| 	result, err := io.ReadAll(encryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read from pass-through reader: %v", err) | |
| 	} | |
| 
 | |
| 	if !bytes.Equal(result, testData) { | |
| 		t.Error("Data should pass through unchanged when key is nil") | |
| 	} | |
| 
 | |
| 	// Test decryption with nil key (should pass through) | |
| 	dataReader2 := bytes.NewReader(testData) | |
| 	decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil, iv) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create decrypted reader with nil key: %v", err) | |
| 	} | |
| 
 | |
| 	result2, err := io.ReadAll(decryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read from pass-through reader: %v", err) | |
| 	} | |
| 
 | |
| 	if !bytes.Equal(result2, testData) { | |
| 		t.Error("Data should pass through unchanged when key is nil") | |
| 	} | |
| } | |
| 
 | |
| // TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers | |
| // could corrupt the data stream when reading in chunks smaller than the IV size | |
| func TestSSECEncryptionSmallBuffers(t *testing.T) { | |
| 	testData := []byte("This is a test message for small buffer reads") | |
| 
 | |
| 	// Create customer key | |
| 	key := make([]byte, 32) | |
| 	for i := range key { | |
| 		key[i] = byte(i) | |
| 	} | |
| 
 | |
| 	md5sumKey3 := md5.Sum(key) | |
| 	customerKey := &SSECustomerKey{ | |
| 		Algorithm: "AES256", | |
| 		Key:       key, | |
| 		KeyMD5:    base64.StdEncoding.EncodeToString(md5sumKey3[:]), | |
| 	} | |
| 
 | |
| 	// Create encrypted reader | |
| 	dataReader := bytes.NewReader(testData) | |
| 	encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create encrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	// Read with very small buffers (smaller than IV size of 16 bytes) | |
| 	var encryptedData []byte | |
| 	smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV | |
|  | |
| 	for { | |
| 		n, err := encryptedReader.Read(smallBuffer) | |
| 		if n > 0 { | |
| 			encryptedData = append(encryptedData, smallBuffer[:n]...) | |
| 		} | |
| 		if err == io.EOF { | |
| 			break | |
| 		} | |
| 		if err != nil { | |
| 			t.Fatalf("Error reading encrypted data: %v", err) | |
| 		} | |
| 	} | |
| 
 | |
| 	// Verify we have some encrypted data (IV is in metadata, not in stream) | |
| 	if len(encryptedData) == 0 && len(testData) > 0 { | |
| 		t.Fatal("Expected encrypted data but got none") | |
| 	} | |
| 
 | |
| 	// Expected size: same as original data (IV is stored in metadata, not in stream) | |
| 	if len(encryptedData) != len(testData) { | |
| 		t.Errorf("Expected encrypted data size %d (same as original), got %d", len(testData), len(encryptedData)) | |
| 	} | |
| 
 | |
| 	// Decrypt and verify | |
| 	encryptedReader2 := bytes.NewReader(encryptedData) | |
| 	decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to create decrypted reader: %v", err) | |
| 	} | |
| 
 | |
| 	decryptedData, err := io.ReadAll(decryptedReader) | |
| 	if err != nil { | |
| 		t.Fatalf("Failed to read decrypted data: %v", err) | |
| 	} | |
| 
 | |
| 	if !bytes.Equal(decryptedData, testData) { | |
| 		t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData) | |
| 	} | |
| }
 |