You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							407 lines
						
					
					
						
							11 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							407 lines
						
					
					
						
							11 KiB
						
					
					
				
								package s3api
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"bytes"
							 | 
						|
									"crypto/md5"
							 | 
						|
									"encoding/base64"
							 | 
						|
									"fmt"
							 | 
						|
									"io"
							 | 
						|
									"net/http"
							 | 
						|
									"testing"
							 | 
						|
								
							 | 
						|
									"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								func base64MD5(b []byte) string {
							 | 
						|
									s := md5.Sum(b)
							 | 
						|
									return base64.StdEncoding.EncodeToString(s[:])
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECHeaderValidation(t *testing.T) {
							 | 
						|
									// Test valid SSE-C headers
							 | 
						|
									req := &http.Request{Header: make(http.Header)}
							 | 
						|
								
							 | 
						|
									key := make([]byte, 32) // 256-bit key
							 | 
						|
									for i := range key {
							 | 
						|
										key[i] = byte(i)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									keyBase64 := base64.StdEncoding.EncodeToString(key)
							 | 
						|
									md5sum := md5.Sum(key)
							 | 
						|
									keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:])
							 | 
						|
								
							 | 
						|
									req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
							 | 
						|
									req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64)
							 | 
						|
									req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5)
							 | 
						|
								
							 | 
						|
									// Test validation
							 | 
						|
									err := ValidateSSECHeaders(req)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Errorf("Expected valid headers, got error: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Test parsing
							 | 
						|
									customerKey, err := ParseSSECHeaders(req)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Errorf("Expected successful parsing, got error: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if customerKey == nil {
							 | 
						|
										t.Error("Expected customer key, got nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if customerKey.Algorithm != "AES256" {
							 | 
						|
										t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !bytes.Equal(customerKey.Key, key) {
							 | 
						|
										t.Error("Key doesn't match original")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if customerKey.KeyMD5 != keyMD5 {
							 | 
						|
										t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5)
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECCopySourceHeaders(t *testing.T) {
							 | 
						|
									// Test valid SSE-C copy source headers
							 | 
						|
									req := &http.Request{Header: make(http.Header)}
							 | 
						|
								
							 | 
						|
									key := make([]byte, 32) // 256-bit key
							 | 
						|
									for i := range key {
							 | 
						|
										key[i] = byte(i) + 1 // Different from regular test
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									keyBase64 := base64.StdEncoding.EncodeToString(key)
							 | 
						|
									md5sum2 := md5.Sum(key)
							 | 
						|
									keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:])
							 | 
						|
								
							 | 
						|
									req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256")
							 | 
						|
									req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64)
							 | 
						|
									req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5)
							 | 
						|
								
							 | 
						|
									// Test parsing copy source headers
							 | 
						|
									customerKey, err := ParseSSECCopySourceHeaders(req)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Errorf("Expected successful copy source parsing, got error: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if customerKey == nil {
							 | 
						|
										t.Error("Expected customer key from copy source headers, got nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if customerKey.Algorithm != "AES256" {
							 | 
						|
										t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !bytes.Equal(customerKey.Key, key) {
							 | 
						|
										t.Error("Copy source key doesn't match original")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Test that regular headers don't interfere with copy source headers
							 | 
						|
									regularKey, err := ParseSSECHeaders(req)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Errorf("Regular header parsing should not fail: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if regularKey != nil {
							 | 
						|
										t.Error("Expected nil for regular headers when only copy source headers are present")
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECHeaderValidationErrors(t *testing.T) {
							 | 
						|
									tests := []struct {
							 | 
						|
										name      string
							 | 
						|
										algorithm string
							 | 
						|
										key       string
							 | 
						|
										keyMD5    string
							 | 
						|
										wantErr   error
							 | 
						|
									}{
							 | 
						|
										{
							 | 
						|
											name:      "invalid algorithm",
							 | 
						|
											algorithm: "AES128",
							 | 
						|
											key:       base64.StdEncoding.EncodeToString(make([]byte, 32)),
							 | 
						|
											keyMD5:    base64MD5(make([]byte, 32)),
							 | 
						|
											wantErr:   ErrInvalidEncryptionAlgorithm,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:      "invalid key length",
							 | 
						|
											algorithm: "AES256",
							 | 
						|
											key:       base64.StdEncoding.EncodeToString(make([]byte, 16)),
							 | 
						|
											keyMD5:    base64MD5(make([]byte, 16)),
							 | 
						|
											wantErr:   ErrInvalidEncryptionKey,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:      "mismatched MD5",
							 | 
						|
											algorithm: "AES256",
							 | 
						|
											key:       base64.StdEncoding.EncodeToString(make([]byte, 32)),
							 | 
						|
											keyMD5:    "wrong==md5",
							 | 
						|
											wantErr:   ErrSSECustomerKeyMD5Mismatch,
							 | 
						|
										},
							 | 
						|
										{
							 | 
						|
											name:      "incomplete headers",
							 | 
						|
											algorithm: "AES256",
							 | 
						|
											key:       "",
							 | 
						|
											keyMD5:    "",
							 | 
						|
											wantErr:   ErrInvalidRequest,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, tt := range tests {
							 | 
						|
										t.Run(tt.name, func(t *testing.T) {
							 | 
						|
											req := &http.Request{Header: make(http.Header)}
							 | 
						|
								
							 | 
						|
											if tt.algorithm != "" {
							 | 
						|
												req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm)
							 | 
						|
											}
							 | 
						|
											if tt.key != "" {
							 | 
						|
												req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key)
							 | 
						|
											}
							 | 
						|
											if tt.keyMD5 != "" {
							 | 
						|
												req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											err := ValidateSSECHeaders(req)
							 | 
						|
											if err != tt.wantErr {
							 | 
						|
												t.Errorf("Expected error %v, got %v", tt.wantErr, err)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECEncryptionDecryption(t *testing.T) {
							 | 
						|
									// Create customer key
							 | 
						|
									key := make([]byte, 32)
							 | 
						|
									for i := range key {
							 | 
						|
										key[i] = byte(i)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									md5sumKey := md5.Sum(key)
							 | 
						|
									customerKey := &SSECustomerKey{
							 | 
						|
										Algorithm: "AES256",
							 | 
						|
										Key:       key,
							 | 
						|
										KeyMD5:    base64.StdEncoding.EncodeToString(md5sumKey[:]),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Test data
							 | 
						|
									testData := []byte("Hello, World! This is a test of SSE-C encryption.")
							 | 
						|
								
							 | 
						|
									// Create encrypted reader
							 | 
						|
									dataReader := bytes.NewReader(testData)
							 | 
						|
									encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create encrypted reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read encrypted data
							 | 
						|
									encryptedData, err := io.ReadAll(encryptedReader)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to read encrypted data: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify data is actually encrypted (different from original)
							 | 
						|
									if bytes.Equal(encryptedData[16:], testData) { // Skip IV
							 | 
						|
										t.Error("Data doesn't appear to be encrypted")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Create decrypted reader
							 | 
						|
									encryptedReader2 := bytes.NewReader(encryptedData)
							 | 
						|
									decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create decrypted reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read decrypted data
							 | 
						|
									decryptedData, err := io.ReadAll(decryptedReader)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to read decrypted data: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify decrypted data matches original
							 | 
						|
									if !bytes.Equal(decryptedData, testData) {
							 | 
						|
										t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECIsSSECRequest(t *testing.T) {
							 | 
						|
									// Test with SSE-C headers
							 | 
						|
									req := &http.Request{Header: make(http.Header)}
							 | 
						|
									req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
							 | 
						|
								
							 | 
						|
									if !IsSSECRequest(req) {
							 | 
						|
										t.Error("Expected IsSSECRequest to return true when SSE-C headers are present")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Test without SSE-C headers
							 | 
						|
									req2 := &http.Request{Header: make(http.Header)}
							 | 
						|
									if IsSSECRequest(req2) {
							 | 
						|
										t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present")
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Test encryption with different data sizes (similar to s3tests)
							 | 
						|
								func TestSSECEncryptionVariousSizes(t *testing.T) {
							 | 
						|
									sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB
							 | 
						|
								
							 | 
						|
									for _, size := range sizes {
							 | 
						|
										t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
							 | 
						|
											// Create customer key
							 | 
						|
											key := make([]byte, 32)
							 | 
						|
											for i := range key {
							 | 
						|
												key[i] = byte(i + size) // Make key unique per test
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											md5sumDyn := md5.Sum(key)
							 | 
						|
											customerKey := &SSECustomerKey{
							 | 
						|
												Algorithm: "AES256",
							 | 
						|
												Key:       key,
							 | 
						|
												KeyMD5:    base64.StdEncoding.EncodeToString(md5sumDyn[:]),
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Create test data of specified size
							 | 
						|
											testData := make([]byte, size)
							 | 
						|
											for i := range testData {
							 | 
						|
												testData[i] = byte('A' + (i % 26)) // Pattern of A-Z
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Encrypt
							 | 
						|
											dataReader := bytes.NewReader(testData)
							 | 
						|
											encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
							 | 
						|
											if err != nil {
							 | 
						|
												t.Fatalf("Failed to create encrypted reader: %v", err)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											encryptedData, err := io.ReadAll(encryptedReader)
							 | 
						|
											if err != nil {
							 | 
						|
												t.Fatalf("Failed to read encrypted data: %v", err)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Verify encrypted data has same size as original (IV is stored in metadata, not in stream)
							 | 
						|
											if len(encryptedData) != size {
							 | 
						|
												t.Errorf("Expected encrypted data length %d (same as original), got %d", size, len(encryptedData))
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Decrypt
							 | 
						|
											encryptedReader2 := bytes.NewReader(encryptedData)
							 | 
						|
											decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
							 | 
						|
											if err != nil {
							 | 
						|
												t.Fatalf("Failed to create decrypted reader: %v", err)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											decryptedData, err := io.ReadAll(decryptedReader)
							 | 
						|
											if err != nil {
							 | 
						|
												t.Fatalf("Failed to read decrypted data: %v", err)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// Verify decrypted data matches original
							 | 
						|
											if !bytes.Equal(decryptedData, testData) {
							 | 
						|
												t.Errorf("Decrypted data doesn't match original for size %d", size)
							 | 
						|
											}
							 | 
						|
										})
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func TestSSECEncryptionWithNilKey(t *testing.T) {
							 | 
						|
									testData := []byte("test data")
							 | 
						|
									dataReader := bytes.NewReader(testData)
							 | 
						|
								
							 | 
						|
									// Test encryption with nil key (should pass through)
							 | 
						|
									encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, nil)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create encrypted reader with nil key: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									result, err := io.ReadAll(encryptedReader)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to read from pass-through reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !bytes.Equal(result, testData) {
							 | 
						|
										t.Error("Data should pass through unchanged when key is nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Test decryption with nil key (should pass through)
							 | 
						|
									dataReader2 := bytes.NewReader(testData)
							 | 
						|
									decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil, iv)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create decrypted reader with nil key: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									result2, err := io.ReadAll(decryptedReader)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to read from pass-through reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !bytes.Equal(result2, testData) {
							 | 
						|
										t.Error("Data should pass through unchanged when key is nil")
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers
							 | 
						|
								// could corrupt the data stream when reading in chunks smaller than the IV size
							 | 
						|
								func TestSSECEncryptionSmallBuffers(t *testing.T) {
							 | 
						|
									testData := []byte("This is a test message for small buffer reads")
							 | 
						|
								
							 | 
						|
									// Create customer key
							 | 
						|
									key := make([]byte, 32)
							 | 
						|
									for i := range key {
							 | 
						|
										key[i] = byte(i)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									md5sumKey3 := md5.Sum(key)
							 | 
						|
									customerKey := &SSECustomerKey{
							 | 
						|
										Algorithm: "AES256",
							 | 
						|
										Key:       key,
							 | 
						|
										KeyMD5:    base64.StdEncoding.EncodeToString(md5sumKey3[:]),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Create encrypted reader
							 | 
						|
									dataReader := bytes.NewReader(testData)
							 | 
						|
									encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create encrypted reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read with very small buffers (smaller than IV size of 16 bytes)
							 | 
						|
									var encryptedData []byte
							 | 
						|
									smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV
							 | 
						|
								
							 | 
						|
									for {
							 | 
						|
										n, err := encryptedReader.Read(smallBuffer)
							 | 
						|
										if n > 0 {
							 | 
						|
											encryptedData = append(encryptedData, smallBuffer[:n]...)
							 | 
						|
										}
							 | 
						|
										if err == io.EOF {
							 | 
						|
											break
							 | 
						|
										}
							 | 
						|
										if err != nil {
							 | 
						|
											t.Fatalf("Error reading encrypted data: %v", err)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify we have some encrypted data (IV is in metadata, not in stream)
							 | 
						|
									if len(encryptedData) == 0 && len(testData) > 0 {
							 | 
						|
										t.Fatal("Expected encrypted data but got none")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Expected size: same as original data (IV is stored in metadata, not in stream)
							 | 
						|
									if len(encryptedData) != len(testData) {
							 | 
						|
										t.Errorf("Expected encrypted data size %d (same as original), got %d", len(testData), len(encryptedData))
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Decrypt and verify
							 | 
						|
									encryptedReader2 := bytes.NewReader(encryptedData)
							 | 
						|
									decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to create decrypted reader: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									decryptedData, err := io.ReadAll(decryptedReader)
							 | 
						|
									if err != nil {
							 | 
						|
										t.Fatalf("Failed to read decrypted data: %v", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !bytes.Equal(decryptedData, testData) {
							 | 
						|
										t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
							 | 
						|
									}
							 | 
						|
								}
							 |