You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							389 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							389 lines
						
					
					
						
							12 KiB
						
					
					
				| package aws | |
| 
 | |
| import ( | |
| 	"context" | |
| 	"encoding/base64" | |
| 	"fmt" | |
| 	"net" | |
| 	"net/http" | |
| 	"strings" | |
| 	"time" | |
| 
 | |
| 	"github.com/aws/aws-sdk-go/aws" | |
| 	"github.com/aws/aws-sdk-go/aws/awserr" | |
| 	"github.com/aws/aws-sdk-go/aws/credentials" | |
| 	"github.com/aws/aws-sdk-go/aws/session" | |
| 	"github.com/aws/aws-sdk-go/service/kms" | |
| 
 | |
| 	"github.com/seaweedfs/seaweedfs/weed/glog" | |
| 	seaweedkms "github.com/seaweedfs/seaweedfs/weed/kms" | |
| 	"github.com/seaweedfs/seaweedfs/weed/util" | |
| ) | |
| 
 | |
| // init runs when this package is imported and registers NewAWSKMSProvider | |
| // with the seaweedkms provider registry under the name "aws". | |
| func init() { | |
| 	// Register the AWS KMS provider | |
| 	seaweedkms.RegisterProvider("aws", NewAWSKMSProvider) | |
| } | |
| 
 | |
| // AWSKMSProvider implements the KMSProvider interface using AWS KMS. | |
| // Instances are built by NewAWSKMSProvider, which populates all three fields. | |
| type AWSKMSProvider struct { | |
| 	// client is the AWS SDK KMS client that every provider method calls. | |
| 	client   *kms.KMS | |
| 	// region is the AWS region the client was configured with. | |
| 	region   string | |
| 	// endpoint is the custom endpoint taken from configuration, stored as given. | |
| 	endpoint string // For testing with LocalStack or custom endpoints | |
| } | |
| 
 | |
| // AWSKMSConfig contains configuration for the AWS KMS provider. | |
| // | |
| // The JSON tags match the key names that NewAWSKMSProvider looks up in its | |
| // util.Configuration argument. | |
| // | |
| // NOTE(review): nothing in the visible part of this file references this | |
| // struct, and NewAWSKMSProvider does not read "role_arn" or "external_id"; | |
| // confirm whether role assumption is handled elsewhere. | |
| type AWSKMSConfig struct { | |
| 	Region         string `json:"region"`          // AWS region (e.g., "us-east-1") | |
| 	AccessKey      string `json:"access_key"`      // AWS access key (optional if using IAM roles) | |
| 	SecretKey      string `json:"secret_key"`      // AWS secret key (optional if using IAM roles) | |
| 	SessionToken   string `json:"session_token"`   // AWS session token (optional for STS) | |
| 	Endpoint       string `json:"endpoint"`        // Custom endpoint (optional, for LocalStack/testing) | |
| 	Profile        string `json:"profile"`         // AWS profile name (optional) | |
| 	RoleARN        string `json:"role_arn"`        // IAM role ARN to assume (optional) | |
| 	ExternalID     string `json:"external_id"`     // External ID for role assumption (optional) | |
| 	ConnectTimeout int    `json:"connect_timeout"` // Connection timeout in seconds (default: 10) | |
| 	RequestTimeout int    `json:"request_timeout"` // Request timeout in seconds (default: 30) | |
| 	MaxRetries     int    `json:"max_retries"`     // Maximum number of retries (default: 3) | |
| } | |
| 
 | |
| // NewAWSKMSProvider creates a new AWS KMS provider | |
| func NewAWSKMSProvider(config util.Configuration) (seaweedkms.KMSProvider, error) { | |
| 	if config == nil { | |
| 		return nil, fmt.Errorf("AWS KMS configuration is required") | |
| 	} | |
| 
 | |
| 	// Extract configuration | |
| 	region := config.GetString("region") | |
| 	if region == "" { | |
| 		region = "us-east-1" // Default region | |
| 	} | |
| 
 | |
| 	accessKey := config.GetString("access_key") | |
| 	secretKey := config.GetString("secret_key") | |
| 	sessionToken := config.GetString("session_token") | |
| 	endpoint := config.GetString("endpoint") | |
| 	profile := config.GetString("profile") | |
| 
 | |
| 	// Timeouts and retries | |
| 	connectTimeout := config.GetInt("connect_timeout") | |
| 	if connectTimeout == 0 { | |
| 		connectTimeout = 10 // Default 10 seconds | |
| 	} | |
| 
 | |
| 	requestTimeout := config.GetInt("request_timeout") | |
| 	if requestTimeout == 0 { | |
| 		requestTimeout = 30 // Default 30 seconds | |
| 	} | |
| 
 | |
| 	maxRetries := config.GetInt("max_retries") | |
| 	if maxRetries == 0 { | |
| 		maxRetries = 3 // Default 3 retries | |
| 	} | |
| 
 | |
| 	// Create AWS session | |
| 	awsConfig := &aws.Config{ | |
| 		Region:     aws.String(region), | |
| 		MaxRetries: aws.Int(maxRetries), | |
| 		HTTPClient: &http.Client{ | |
| 			Timeout: time.Duration(requestTimeout) * time.Second, | |
| 		}, | |
| 	} | |
| 
 | |
| 	// Set custom endpoint if provided (for testing with LocalStack) | |
| 	if endpoint != "" { | |
| 		awsConfig.Endpoint = aws.String(endpoint) | |
| 		awsConfig.DisableSSL = aws.Bool(strings.HasPrefix(endpoint, "http://")) | |
| 	} | |
| 
 | |
| 	// Configure credentials | |
| 	if accessKey != "" && secretKey != "" { | |
| 		awsConfig.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken) | |
| 	} else if profile != "" { | |
| 		awsConfig.Credentials = credentials.NewSharedCredentials("", profile) | |
| 	} | |
| 	// If neither are provided, use default credential chain (IAM roles, etc.) | |
|  | |
| 	sess, err := session.NewSession(awsConfig) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to create AWS session: %w", err) | |
| 	} | |
| 
 | |
| 	provider := &AWSKMSProvider{ | |
| 		client:   kms.New(sess), | |
| 		region:   region, | |
| 		endpoint: endpoint, | |
| 	} | |
| 
 | |
| 	glog.V(1).Infof("AWS KMS provider initialized for region %s", region) | |
| 	return provider, nil | |
| } | |
| 
 | |
| // GenerateDataKey generates a new data encryption key using AWS KMS | |
| func (p *AWSKMSProvider) GenerateDataKey(ctx context.Context, req *seaweedkms.GenerateDataKeyRequest) (*seaweedkms.GenerateDataKeyResponse, error) { | |
| 	if req == nil { | |
| 		return nil, fmt.Errorf("GenerateDataKeyRequest cannot be nil") | |
| 	} | |
| 
 | |
| 	if req.KeyID == "" { | |
| 		return nil, fmt.Errorf("KeyID is required") | |
| 	} | |
| 
 | |
| 	// Validate key spec | |
| 	var keySpec string | |
| 	switch req.KeySpec { | |
| 	case seaweedkms.KeySpecAES256: | |
| 		keySpec = "AES_256" | |
| 	default: | |
| 		return nil, fmt.Errorf("unsupported key spec: %s", req.KeySpec) | |
| 	} | |
| 
 | |
| 	// Build KMS request | |
| 	kmsReq := &kms.GenerateDataKeyInput{ | |
| 		KeyId:   aws.String(req.KeyID), | |
| 		KeySpec: aws.String(keySpec), | |
| 	} | |
| 
 | |
| 	// Add encryption context if provided | |
| 	if len(req.EncryptionContext) > 0 { | |
| 		kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext) | |
| 	} | |
| 
 | |
| 	// Call AWS KMS | |
| 	glog.V(4).Infof("AWS KMS: Generating data key for key ID %s", req.KeyID) | |
| 	result, err := p.client.GenerateDataKeyWithContext(ctx, kmsReq) | |
| 	if err != nil { | |
| 		return nil, p.convertAWSError(err, req.KeyID) | |
| 	} | |
| 
 | |
| 	// Extract the actual key ID from the response (resolves aliases) | |
| 	actualKeyID := "" | |
| 	if result.KeyId != nil { | |
| 		actualKeyID = *result.KeyId | |
| 	} | |
| 
 | |
| 	// Create standardized envelope format for consistent API behavior | |
| 	envelopeBlob, err := seaweedkms.CreateEnvelope("aws", actualKeyID, base64.StdEncoding.EncodeToString(result.CiphertextBlob), nil) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to create ciphertext envelope: %w", err) | |
| 	} | |
| 
 | |
| 	response := &seaweedkms.GenerateDataKeyResponse{ | |
| 		KeyID:          actualKeyID, | |
| 		Plaintext:      result.Plaintext, | |
| 		CiphertextBlob: envelopeBlob, // Store in standardized envelope format | |
| 	} | |
| 
 | |
| 	glog.V(4).Infof("AWS KMS: Generated data key for key ID %s (actual: %s)", req.KeyID, actualKeyID) | |
| 	return response, nil | |
| } | |
| 
 | |
| // Decrypt decrypts an encrypted data key using AWS KMS | |
| func (p *AWSKMSProvider) Decrypt(ctx context.Context, req *seaweedkms.DecryptRequest) (*seaweedkms.DecryptResponse, error) { | |
| 	if req == nil { | |
| 		return nil, fmt.Errorf("DecryptRequest cannot be nil") | |
| 	} | |
| 
 | |
| 	if len(req.CiphertextBlob) == 0 { | |
| 		return nil, fmt.Errorf("CiphertextBlob cannot be empty") | |
| 	} | |
| 
 | |
| 	// Parse the ciphertext envelope to extract key information | |
| 	envelope, err := seaweedkms.ParseEnvelope(req.CiphertextBlob) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to parse ciphertext envelope: %w", err) | |
| 	} | |
| 
 | |
| 	if envelope.Provider != "aws" { | |
| 		return nil, fmt.Errorf("invalid provider in envelope: expected 'aws', got '%s'", envelope.Provider) | |
| 	} | |
| 
 | |
| 	ciphertext, err := base64.StdEncoding.DecodeString(envelope.Ciphertext) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to decode ciphertext from envelope: %w", err) | |
| 	} | |
| 
 | |
| 	// Build KMS request | |
| 	kmsReq := &kms.DecryptInput{ | |
| 		CiphertextBlob: ciphertext, | |
| 	} | |
| 
 | |
| 	// Add encryption context if provided | |
| 	if len(req.EncryptionContext) > 0 { | |
| 		kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext) | |
| 	} | |
| 
 | |
| 	// Call AWS KMS | |
| 	glog.V(4).Infof("AWS KMS: Decrypting data key (blob size: %d bytes)", len(req.CiphertextBlob)) | |
| 	result, err := p.client.DecryptWithContext(ctx, kmsReq) | |
| 	if err != nil { | |
| 		return nil, p.convertAWSError(err, "") | |
| 	} | |
| 
 | |
| 	// Extract the key ID that was used for encryption | |
| 	keyID := "" | |
| 	if result.KeyId != nil { | |
| 		keyID = *result.KeyId | |
| 	} | |
| 
 | |
| 	response := &seaweedkms.DecryptResponse{ | |
| 		KeyID:     keyID, | |
| 		Plaintext: result.Plaintext, | |
| 	} | |
| 
 | |
| 	glog.V(4).Infof("AWS KMS: Decrypted data key using key ID %s", keyID) | |
| 	return response, nil | |
| } | |
| 
 | |
| // DescribeKey validates that a key exists and returns its metadata | |
| func (p *AWSKMSProvider) DescribeKey(ctx context.Context, req *seaweedkms.DescribeKeyRequest) (*seaweedkms.DescribeKeyResponse, error) { | |
| 	if req == nil { | |
| 		return nil, fmt.Errorf("DescribeKeyRequest cannot be nil") | |
| 	} | |
| 
 | |
| 	if req.KeyID == "" { | |
| 		return nil, fmt.Errorf("KeyID is required") | |
| 	} | |
| 
 | |
| 	// Build KMS request | |
| 	kmsReq := &kms.DescribeKeyInput{ | |
| 		KeyId: aws.String(req.KeyID), | |
| 	} | |
| 
 | |
| 	// Call AWS KMS | |
| 	glog.V(4).Infof("AWS KMS: Describing key %s", req.KeyID) | |
| 	result, err := p.client.DescribeKeyWithContext(ctx, kmsReq) | |
| 	if err != nil { | |
| 		return nil, p.convertAWSError(err, req.KeyID) | |
| 	} | |
| 
 | |
| 	if result.KeyMetadata == nil { | |
| 		return nil, fmt.Errorf("no key metadata returned from AWS KMS") | |
| 	} | |
| 
 | |
| 	metadata := result.KeyMetadata | |
| 	response := &seaweedkms.DescribeKeyResponse{ | |
| 		KeyID:       aws.StringValue(metadata.KeyId), | |
| 		ARN:         aws.StringValue(metadata.Arn), | |
| 		Description: aws.StringValue(metadata.Description), | |
| 	} | |
| 
 | |
| 	// Convert AWS key usage to our enum | |
| 	if metadata.KeyUsage != nil { | |
| 		switch *metadata.KeyUsage { | |
| 		case "ENCRYPT_DECRYPT": | |
| 			response.KeyUsage = seaweedkms.KeyUsageEncryptDecrypt | |
| 		case "GENERATE_DATA_KEY": | |
| 			response.KeyUsage = seaweedkms.KeyUsageGenerateDataKey | |
| 		} | |
| 	} | |
| 
 | |
| 	// Convert AWS key state to our enum | |
| 	if metadata.KeyState != nil { | |
| 		switch *metadata.KeyState { | |
| 		case "Enabled": | |
| 			response.KeyState = seaweedkms.KeyStateEnabled | |
| 		case "Disabled": | |
| 			response.KeyState = seaweedkms.KeyStateDisabled | |
| 		case "PendingDeletion": | |
| 			response.KeyState = seaweedkms.KeyStatePendingDeletion | |
| 		case "Unavailable": | |
| 			response.KeyState = seaweedkms.KeyStateUnavailable | |
| 		} | |
| 	} | |
| 
 | |
| 	// Convert AWS origin to our enum | |
| 	if metadata.Origin != nil { | |
| 		switch *metadata.Origin { | |
| 		case "AWS_KMS": | |
| 			response.Origin = seaweedkms.KeyOriginAWS | |
| 		case "EXTERNAL": | |
| 			response.Origin = seaweedkms.KeyOriginExternal | |
| 		case "AWS_CLOUDHSM": | |
| 			response.Origin = seaweedkms.KeyOriginCloudHSM | |
| 		} | |
| 	} | |
| 
 | |
| 	glog.V(4).Infof("AWS KMS: Described key %s (actual: %s, state: %s)", req.KeyID, response.KeyID, response.KeyState) | |
| 	return response, nil | |
| } | |
| 
 | |
| // GetKeyID resolves a key alias or ARN to the actual key ID | |
| func (p *AWSKMSProvider) GetKeyID(ctx context.Context, keyIdentifier string) (string, error) { | |
| 	if keyIdentifier == "" { | |
| 		return "", fmt.Errorf("key identifier cannot be empty") | |
| 	} | |
| 
 | |
| 	// Use DescribeKey to resolve the key identifier | |
| 	descReq := &seaweedkms.DescribeKeyRequest{KeyID: keyIdentifier} | |
| 	descResp, err := p.DescribeKey(ctx, descReq) | |
| 	if err != nil { | |
| 		return "", fmt.Errorf("failed to resolve key identifier %s: %w", keyIdentifier, err) | |
| 	} | |
| 
 | |
| 	return descResp.KeyID, nil | |
| } | |
| 
 | |
| // Close cleans up any resources used by the provider. | |
| // It releases nothing: it only writes a verbosity-2 log line and always | |
| // returns nil. | |
| func (p *AWSKMSProvider) Close() error { | |
| 	// AWS SDK clients don't require explicit cleanup | |
| 	glog.V(2).Infof("AWS KMS provider closed") | |
| 	return nil | |
| } | |
| 
 | |
| // convertAWSError converts AWS KMS errors to our standard KMS errors | |
| func (p *AWSKMSProvider) convertAWSError(err error, keyID string) error { | |
| 	if awsErr, ok := err.(awserr.Error); ok { | |
| 		switch awsErr.Code() { | |
| 		case "NotFoundException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeNotFoundException, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		case "DisabledException", "KeyUnavailableException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeKeyUnavailable, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		case "AccessDeniedException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeAccessDenied, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		case "InvalidKeyUsageException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeInvalidKeyUsage, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		case "InvalidCiphertextException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeInvalidCiphertext, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		case "KMSInternalException", "KMSInvalidStateException": | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeKMSInternalFailure, | |
| 				Message: awsErr.Message(), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		default: | |
| 			// For unknown AWS errors, wrap them as internal failures | |
| 			return &seaweedkms.KMSError{ | |
| 				Code:    seaweedkms.ErrCodeKMSInternalFailure, | |
| 				Message: fmt.Sprintf("AWS KMS error %s: %s", awsErr.Code(), awsErr.Message()), | |
| 				KeyID:   keyID, | |
| 			} | |
| 		} | |
| 	} | |
| 
 | |
| 	// For non-AWS errors (network issues, etc.), wrap as internal failure | |
| 	return &seaweedkms.KMSError{ | |
| 		Code:    seaweedkms.ErrCodeKMSInternalFailure, | |
| 		Message: fmt.Sprintf("AWS KMS provider error: %v", err), | |
| 		KeyID:   keyID, | |
| 	} | |
| }
 |