You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							389 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							389 lines
						
					
					
						
							12 KiB
						
					
					
				
								package aws
							 | 
						|
								
							 | 
						|
								import (
									"context"
									"encoding/base64"
									"fmt"
									"net"
									"net/http"
									"strings"
									"time"

									"github.com/aws/aws-sdk-go/aws"
									"github.com/aws/aws-sdk-go/aws/awserr"
									"github.com/aws/aws-sdk-go/aws/credentials"
									"github.com/aws/aws-sdk-go/aws/session"
									"github.com/aws/aws-sdk-go/service/kms"

									"github.com/seaweedfs/seaweedfs/weed/glog"
									seaweedkms "github.com/seaweedfs/seaweedfs/weed/kms"
									"github.com/seaweedfs/seaweedfs/weed/util"
								)
							 | 
						|
								
							 | 
						|
								func init() {
							 | 
						|
									// Register the AWS KMS provider
							 | 
						|
									seaweedkms.RegisterProvider("aws", NewAWSKMSProvider)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// AWSKMSProvider implements the KMSProvider interface using AWS KMS
							 | 
						|
								type AWSKMSProvider struct {
							 | 
						|
									client   *kms.KMS
							 | 
						|
									region   string
							 | 
						|
									endpoint string // For testing with LocalStack or custom endpoints
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// AWSKMSConfig describes the configuration keys understood by the AWS KMS
								// provider, together with their JSON names.
								//
								// NOTE(review): NewAWSKMSProvider in this file reads its settings by key
								// from a util.Configuration and does not read role_arn or external_id;
								// confirm where those two fields are consumed.
								type AWSKMSConfig struct {
									// Region is the AWS region (e.g., "us-east-1").
									Region string `json:"region"`
									// AccessKey is the AWS access key (optional if using IAM roles).
									AccessKey string `json:"access_key"`
									// SecretKey is the AWS secret key (optional if using IAM roles).
									SecretKey string `json:"secret_key"`
									// SessionToken is the AWS session token (optional, for STS).
									SessionToken string `json:"session_token"`
									// Endpoint is a custom endpoint (optional, for LocalStack/testing).
									Endpoint string `json:"endpoint"`
									// Profile is the AWS profile name (optional).
									Profile string `json:"profile"`
									// RoleARN is the IAM role ARN to assume (optional).
									RoleARN string `json:"role_arn"`
									// ExternalID is the external ID for role assumption (optional).
									ExternalID string `json:"external_id"`
									// ConnectTimeout is the connection timeout in seconds (default: 10).
									ConnectTimeout int `json:"connect_timeout"`
									// RequestTimeout is the request timeout in seconds (default: 30).
									RequestTimeout int `json:"request_timeout"`
									// MaxRetries is the maximum number of retries (default: 3).
									MaxRetries int `json:"max_retries"`
								}
							 | 
						|
								
							 | 
						|
								// NewAWSKMSProvider creates a new AWS KMS provider
							 | 
						|
								func NewAWSKMSProvider(config util.Configuration) (seaweedkms.KMSProvider, error) {
							 | 
						|
									if config == nil {
							 | 
						|
										return nil, fmt.Errorf("AWS KMS configuration is required")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Extract configuration
							 | 
						|
									region := config.GetString("region")
							 | 
						|
									if region == "" {
							 | 
						|
										region = "us-east-1" // Default region
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									accessKey := config.GetString("access_key")
							 | 
						|
									secretKey := config.GetString("secret_key")
							 | 
						|
									sessionToken := config.GetString("session_token")
							 | 
						|
									endpoint := config.GetString("endpoint")
							 | 
						|
									profile := config.GetString("profile")
							 | 
						|
								
							 | 
						|
									// Timeouts and retries
							 | 
						|
									connectTimeout := config.GetInt("connect_timeout")
							 | 
						|
									if connectTimeout == 0 {
							 | 
						|
										connectTimeout = 10 // Default 10 seconds
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									requestTimeout := config.GetInt("request_timeout")
							 | 
						|
									if requestTimeout == 0 {
							 | 
						|
										requestTimeout = 30 // Default 30 seconds
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									maxRetries := config.GetInt("max_retries")
							 | 
						|
									if maxRetries == 0 {
							 | 
						|
										maxRetries = 3 // Default 3 retries
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Create AWS session
							 | 
						|
									awsConfig := &aws.Config{
							 | 
						|
										Region:     aws.String(region),
							 | 
						|
										MaxRetries: aws.Int(maxRetries),
							 | 
						|
										HTTPClient: &http.Client{
							 | 
						|
											Timeout: time.Duration(requestTimeout) * time.Second,
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Set custom endpoint if provided (for testing with LocalStack)
							 | 
						|
									if endpoint != "" {
							 | 
						|
										awsConfig.Endpoint = aws.String(endpoint)
							 | 
						|
										awsConfig.DisableSSL = aws.Bool(strings.HasPrefix(endpoint, "http://"))
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Configure credentials
							 | 
						|
									if accessKey != "" && secretKey != "" {
							 | 
						|
										awsConfig.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)
							 | 
						|
									} else if profile != "" {
							 | 
						|
										awsConfig.Credentials = credentials.NewSharedCredentials("", profile)
							 | 
						|
									}
							 | 
						|
									// If neither are provided, use default credential chain (IAM roles, etc.)
							 | 
						|
								
							 | 
						|
									sess, err := session.NewSession(awsConfig)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to create AWS session: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									provider := &AWSKMSProvider{
							 | 
						|
										client:   kms.New(sess),
							 | 
						|
										region:   region,
							 | 
						|
										endpoint: endpoint,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									glog.V(1).Infof("AWS KMS provider initialized for region %s", region)
							 | 
						|
									return provider, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GenerateDataKey generates a new data encryption key using AWS KMS
							 | 
						|
								func (p *AWSKMSProvider) GenerateDataKey(ctx context.Context, req *seaweedkms.GenerateDataKeyRequest) (*seaweedkms.GenerateDataKeyResponse, error) {
							 | 
						|
									if req == nil {
							 | 
						|
										return nil, fmt.Errorf("GenerateDataKeyRequest cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if req.KeyID == "" {
							 | 
						|
										return nil, fmt.Errorf("KeyID is required")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate key spec
							 | 
						|
									var keySpec string
							 | 
						|
									switch req.KeySpec {
							 | 
						|
									case seaweedkms.KeySpecAES256:
							 | 
						|
										keySpec = "AES_256"
							 | 
						|
									default:
							 | 
						|
										return nil, fmt.Errorf("unsupported key spec: %s", req.KeySpec)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Build KMS request
							 | 
						|
									kmsReq := &kms.GenerateDataKeyInput{
							 | 
						|
										KeyId:   aws.String(req.KeyID),
							 | 
						|
										KeySpec: aws.String(keySpec),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Add encryption context if provided
							 | 
						|
									if len(req.EncryptionContext) > 0 {
							 | 
						|
										kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Call AWS KMS
							 | 
						|
									glog.V(4).Infof("AWS KMS: Generating data key for key ID %s", req.KeyID)
							 | 
						|
									result, err := p.client.GenerateDataKeyWithContext(ctx, kmsReq)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, p.convertAWSError(err, req.KeyID)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Extract the actual key ID from the response (resolves aliases)
							 | 
						|
									actualKeyID := ""
							 | 
						|
									if result.KeyId != nil {
							 | 
						|
										actualKeyID = *result.KeyId
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Create standardized envelope format for consistent API behavior
							 | 
						|
									envelopeBlob, err := seaweedkms.CreateEnvelope("aws", actualKeyID, base64.StdEncoding.EncodeToString(result.CiphertextBlob), nil)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to create ciphertext envelope: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									response := &seaweedkms.GenerateDataKeyResponse{
							 | 
						|
										KeyID:          actualKeyID,
							 | 
						|
										Plaintext:      result.Plaintext,
							 | 
						|
										CiphertextBlob: envelopeBlob, // Store in standardized envelope format
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									glog.V(4).Infof("AWS KMS: Generated data key for key ID %s (actual: %s)", req.KeyID, actualKeyID)
							 | 
						|
									return response, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Decrypt decrypts an encrypted data key using AWS KMS
							 | 
						|
								func (p *AWSKMSProvider) Decrypt(ctx context.Context, req *seaweedkms.DecryptRequest) (*seaweedkms.DecryptResponse, error) {
							 | 
						|
									if req == nil {
							 | 
						|
										return nil, fmt.Errorf("DecryptRequest cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if len(req.CiphertextBlob) == 0 {
							 | 
						|
										return nil, fmt.Errorf("CiphertextBlob cannot be empty")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Parse the ciphertext envelope to extract key information
							 | 
						|
									envelope, err := seaweedkms.ParseEnvelope(req.CiphertextBlob)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to parse ciphertext envelope: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if envelope.Provider != "aws" {
							 | 
						|
										return nil, fmt.Errorf("invalid provider in envelope: expected 'aws', got '%s'", envelope.Provider)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									ciphertext, err := base64.StdEncoding.DecodeString(envelope.Ciphertext)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to decode ciphertext from envelope: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Build KMS request
							 | 
						|
									kmsReq := &kms.DecryptInput{
							 | 
						|
										CiphertextBlob: ciphertext,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Add encryption context if provided
							 | 
						|
									if len(req.EncryptionContext) > 0 {
							 | 
						|
										kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Call AWS KMS
							 | 
						|
									glog.V(4).Infof("AWS KMS: Decrypting data key (blob size: %d bytes)", len(req.CiphertextBlob))
							 | 
						|
									result, err := p.client.DecryptWithContext(ctx, kmsReq)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, p.convertAWSError(err, "")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Extract the key ID that was used for encryption
							 | 
						|
									keyID := ""
							 | 
						|
									if result.KeyId != nil {
							 | 
						|
										keyID = *result.KeyId
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									response := &seaweedkms.DecryptResponse{
							 | 
						|
										KeyID:     keyID,
							 | 
						|
										Plaintext: result.Plaintext,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									glog.V(4).Infof("AWS KMS: Decrypted data key using key ID %s", keyID)
							 | 
						|
									return response, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// DescribeKey validates that a key exists and returns its metadata
							 | 
						|
								func (p *AWSKMSProvider) DescribeKey(ctx context.Context, req *seaweedkms.DescribeKeyRequest) (*seaweedkms.DescribeKeyResponse, error) {
							 | 
						|
									if req == nil {
							 | 
						|
										return nil, fmt.Errorf("DescribeKeyRequest cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if req.KeyID == "" {
							 | 
						|
										return nil, fmt.Errorf("KeyID is required")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Build KMS request
							 | 
						|
									kmsReq := &kms.DescribeKeyInput{
							 | 
						|
										KeyId: aws.String(req.KeyID),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Call AWS KMS
							 | 
						|
									glog.V(4).Infof("AWS KMS: Describing key %s", req.KeyID)
							 | 
						|
									result, err := p.client.DescribeKeyWithContext(ctx, kmsReq)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, p.convertAWSError(err, req.KeyID)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if result.KeyMetadata == nil {
							 | 
						|
										return nil, fmt.Errorf("no key metadata returned from AWS KMS")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									metadata := result.KeyMetadata
							 | 
						|
									response := &seaweedkms.DescribeKeyResponse{
							 | 
						|
										KeyID:       aws.StringValue(metadata.KeyId),
							 | 
						|
										ARN:         aws.StringValue(metadata.Arn),
							 | 
						|
										Description: aws.StringValue(metadata.Description),
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert AWS key usage to our enum
							 | 
						|
									if metadata.KeyUsage != nil {
							 | 
						|
										switch *metadata.KeyUsage {
							 | 
						|
										case "ENCRYPT_DECRYPT":
							 | 
						|
											response.KeyUsage = seaweedkms.KeyUsageEncryptDecrypt
							 | 
						|
										case "GENERATE_DATA_KEY":
							 | 
						|
											response.KeyUsage = seaweedkms.KeyUsageGenerateDataKey
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert AWS key state to our enum
							 | 
						|
									if metadata.KeyState != nil {
							 | 
						|
										switch *metadata.KeyState {
							 | 
						|
										case "Enabled":
							 | 
						|
											response.KeyState = seaweedkms.KeyStateEnabled
							 | 
						|
										case "Disabled":
							 | 
						|
											response.KeyState = seaweedkms.KeyStateDisabled
							 | 
						|
										case "PendingDeletion":
							 | 
						|
											response.KeyState = seaweedkms.KeyStatePendingDeletion
							 | 
						|
										case "Unavailable":
							 | 
						|
											response.KeyState = seaweedkms.KeyStateUnavailable
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert AWS origin to our enum
							 | 
						|
									if metadata.Origin != nil {
							 | 
						|
										switch *metadata.Origin {
							 | 
						|
										case "AWS_KMS":
							 | 
						|
											response.Origin = seaweedkms.KeyOriginAWS
							 | 
						|
										case "EXTERNAL":
							 | 
						|
											response.Origin = seaweedkms.KeyOriginExternal
							 | 
						|
										case "AWS_CLOUDHSM":
							 | 
						|
											response.Origin = seaweedkms.KeyOriginCloudHSM
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									glog.V(4).Infof("AWS KMS: Described key %s (actual: %s, state: %s)", req.KeyID, response.KeyID, response.KeyState)
							 | 
						|
									return response, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GetKeyID resolves a key alias or ARN to the actual key ID
							 | 
						|
								func (p *AWSKMSProvider) GetKeyID(ctx context.Context, keyIdentifier string) (string, error) {
							 | 
						|
									if keyIdentifier == "" {
							 | 
						|
										return "", fmt.Errorf("key identifier cannot be empty")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Use DescribeKey to resolve the key identifier
							 | 
						|
									descReq := &seaweedkms.DescribeKeyRequest{KeyID: keyIdentifier}
							 | 
						|
									descResp, err := p.DescribeKey(ctx, descReq)
							 | 
						|
									if err != nil {
							 | 
						|
										return "", fmt.Errorf("failed to resolve key identifier %s: %w", keyIdentifier, err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return descResp.KeyID, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Close cleans up any resources used by the provider
							 | 
						|
								func (p *AWSKMSProvider) Close() error {
							 | 
						|
									// AWS SDK clients don't require explicit cleanup
							 | 
						|
									glog.V(2).Infof("AWS KMS provider closed")
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// convertAWSError converts AWS KMS errors to our standard KMS errors
							 | 
						|
								func (p *AWSKMSProvider) convertAWSError(err error, keyID string) error {
							 | 
						|
									if awsErr, ok := err.(awserr.Error); ok {
							 | 
						|
										switch awsErr.Code() {
							 | 
						|
										case "NotFoundException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeNotFoundException,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										case "DisabledException", "KeyUnavailableException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeKeyUnavailable,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										case "AccessDeniedException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeAccessDenied,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										case "InvalidKeyUsageException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeInvalidKeyUsage,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										case "InvalidCiphertextException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeInvalidCiphertext,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										case "KMSInternalException", "KMSInvalidStateException":
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeKMSInternalFailure,
							 | 
						|
												Message: awsErr.Message(),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										default:
							 | 
						|
											// For unknown AWS errors, wrap them as internal failures
							 | 
						|
											return &seaweedkms.KMSError{
							 | 
						|
												Code:    seaweedkms.ErrCodeKMSInternalFailure,
							 | 
						|
												Message: fmt.Sprintf("AWS KMS error %s: %s", awsErr.Code(), awsErr.Message()),
							 | 
						|
												KeyID:   keyID,
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// For non-AWS errors (network issues, etc.), wrap as internal failure
							 | 
						|
									return &seaweedkms.KMSError{
							 | 
						|
										Code:    seaweedkms.ErrCodeKMSInternalFailure,
							 | 
						|
										Message: fmt.Sprintf("AWS KMS provider error: %v", err),
							 | 
						|
										KeyID:   keyID,
							 | 
						|
									}
							 | 
						|
								}
							 |