You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							145 lines
						
					
					
						
							3.8 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							145 lines
						
					
					
						
							3.8 KiB
						
					
					
				| package kms | |
| 
 | |
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"github.com/seaweedfs/seaweedfs/weed/util"
)
| 
 | |
// ProviderRegistry manages KMS provider implementations.
//
// It maps provider names to the factories that build them (providers) and
// caches the instance built for each name (instances). Every method on the
// registry locks mu before touching either map.
type ProviderRegistry struct {
	mu        sync.RWMutex               // guards providers and instances
	providers map[string]ProviderFactory // registered factories, keyed by provider name
	instances map[string]KMSProvider     // instances built by GetProvider, keyed by provider name
}
| 
 | |
// ProviderFactory creates a new KMS provider instance from config.
// ProviderRegistry.GetProvider calls it only when no instance is cached under
// the requested name, and caches the instance it returns on success.
type ProviderFactory func(config util.Configuration) (KMSProvider, error)
| 
 | |
// defaultRegistry is the package-level registry behind the package helpers
// RegisterProvider, GetProvider, ListProviders and CloseAll.
var defaultRegistry = NewProviderRegistry()
| 
 | |
| // NewProviderRegistry creates a new provider registry | |
| func NewProviderRegistry() *ProviderRegistry { | |
| 	return &ProviderRegistry{ | |
| 		providers: make(map[string]ProviderFactory), | |
| 		instances: make(map[string]KMSProvider), | |
| 	} | |
| } | |
| 
 | |
| // RegisterProvider registers a new KMS provider factory | |
| func RegisterProvider(name string, factory ProviderFactory) { | |
| 	defaultRegistry.RegisterProvider(name, factory) | |
| } | |
| 
 | |
| // RegisterProvider registers a new KMS provider factory in this registry | |
| func (r *ProviderRegistry) RegisterProvider(name string, factory ProviderFactory) { | |
| 	r.mu.Lock() | |
| 	defer r.mu.Unlock() | |
| 	r.providers[name] = factory | |
| } | |
| 
 | |
| // GetProvider returns a KMS provider instance, creating it if necessary | |
| func GetProvider(name string, config util.Configuration) (KMSProvider, error) { | |
| 	return defaultRegistry.GetProvider(name, config) | |
| } | |
| 
 | |
| // GetProvider returns a KMS provider instance, creating it if necessary | |
| func (r *ProviderRegistry) GetProvider(name string, config util.Configuration) (KMSProvider, error) { | |
| 	r.mu.Lock() | |
| 	defer r.mu.Unlock() | |
| 
 | |
| 	// Return existing instance if available | |
| 	if instance, exists := r.instances[name]; exists { | |
| 		return instance, nil | |
| 	} | |
| 
 | |
| 	// Find the factory | |
| 	factory, exists := r.providers[name] | |
| 	if !exists { | |
| 		return nil, fmt.Errorf("KMS provider '%s' not registered", name) | |
| 	} | |
| 
 | |
| 	// Create new instance | |
| 	instance, err := factory(config) | |
| 	if err != nil { | |
| 		return nil, fmt.Errorf("failed to create KMS provider '%s': %v", name, err) | |
| 	} | |
| 
 | |
| 	// Cache the instance | |
| 	r.instances[name] = instance | |
| 	return instance, nil | |
| } | |
| 
 | |
| // ListProviders returns the names of all registered providers | |
| func ListProviders() []string { | |
| 	return defaultRegistry.ListProviders() | |
| } | |
| 
 | |
| // ListProviders returns the names of all registered providers | |
| func (r *ProviderRegistry) ListProviders() []string { | |
| 	r.mu.RLock() | |
| 	defer r.mu.RUnlock() | |
| 
 | |
| 	names := make([]string, 0, len(r.providers)) | |
| 	for name := range r.providers { | |
| 		names = append(names, name) | |
| 	} | |
| 	return names | |
| } | |
| 
 | |
| // CloseAll closes all provider instances | |
| func CloseAll() error { | |
| 	return defaultRegistry.CloseAll() | |
| } | |
| 
 | |
| // CloseAll closes all provider instances in this registry | |
| func (r *ProviderRegistry) CloseAll() error { | |
| 	r.mu.Lock() | |
| 	defer r.mu.Unlock() | |
| 
 | |
| 	var allErrors []error | |
| 	for name, instance := range r.instances { | |
| 		if err := instance.Close(); err != nil { | |
| 			allErrors = append(allErrors, fmt.Errorf("failed to close KMS provider '%s': %w", name, err)) | |
| 		} | |
| 	} | |
| 
 | |
| 	// Clear the instances map | |
| 	r.instances = make(map[string]KMSProvider) | |
| 
 | |
| 	return errors.Join(allErrors...) | |
| } | |
| 
 | |
| // WithKMSProvider is a helper function to execute code with a KMS provider | |
| func WithKMSProvider(name string, config util.Configuration, fn func(KMSProvider) error) error { | |
| 	provider, err := GetProvider(name, config) | |
| 	if err != nil { | |
| 		return err | |
| 	} | |
| 	return fn(provider) | |
| } | |
| 
 | |
| // TestKMSConnection tests the connection to a KMS provider | |
| func TestKMSConnection(ctx context.Context, provider KMSProvider, testKeyID string) error { | |
| 	if provider == nil { | |
| 		return fmt.Errorf("KMS provider is nil") | |
| 	} | |
| 
 | |
| 	// Try to describe a test key to verify connectivity | |
| 	_, err := provider.DescribeKey(ctx, &DescribeKeyRequest{ | |
| 		KeyID: testKeyID, | |
| 	}) | |
| 
 | |
| 	if err != nil { | |
| 		// If the key doesn't exist, that's still a successful connection test | |
| 		if kmsErr, ok := err.(*KMSError); ok && kmsErr.Code == ErrCodeNotFoundException { | |
| 			return nil | |
| 		} | |
| 		return fmt.Errorf("KMS connection test failed: %v", err) | |
| 	} | |
| 
 | |
| 	return nil | |
| }
 |