You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

274 lines
6.7 KiB

package kms
import (
"context"
"errors"
"fmt"
"sync"
"github.com/seaweedfs/seaweedfs/weed/util"
)
// ProviderRegistry manages KMS provider implementations.
//
// It keeps two maps keyed by provider name: the registered factories and the
// provider instances created from them. Its methods lock mu before touching
// either map.
type ProviderRegistry struct {
mu sync.RWMutex // guards providers and instances
providers map[string]ProviderFactory // registered factories, keyed by provider name
instances map[string]KMSProvider // instances cached by GetProvider, keyed by provider name
}
// ProviderFactory creates a new KMS provider instance from the given
// configuration, or reports why it could not. InitializeGlobalKMS passes a
// nil configuration when KMSConfig.Config is nil.
type ProviderFactory func(config util.Configuration) (KMSProvider, error)
// defaultRegistry is the package-level registry that backs the package-level
// RegisterProvider, GetProvider, ListProviders and CloseAll functions.
var defaultRegistry = NewProviderRegistry()
// NewProviderRegistry returns an empty registry whose factory and instance
// maps are already allocated, so it can be used immediately.
func NewProviderRegistry() *ProviderRegistry {
	registry := new(ProviderRegistry)
	registry.providers = make(map[string]ProviderFactory)
	registry.instances = make(map[string]KMSProvider)
	return registry
}
// RegisterProvider registers a KMS provider factory under name in the
// package-level default registry.
func RegisterProvider(name string, factory ProviderFactory) {
defaultRegistry.RegisterProvider(name, factory)
}
// RegisterProvider registers a KMS provider factory under name in this
// registry. Registering a name that is already present replaces the earlier
// factory; instances that were already created are left untouched.
func (r *ProviderRegistry) RegisterProvider(name string, factory ProviderFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.providers[name] = factory
}
// GetProvider returns the KMS provider instance for name from the
// package-level default registry, creating it if necessary.
func GetProvider(name string, config util.Configuration) (KMSProvider, error) {
return defaultRegistry.GetProvider(name, config)
}
// GetProvider returns the KMS provider instance registered under name,
// creating and caching it on first use.
//
// If an instance for name is already cached it is returned as-is and config is
// ignored. Otherwise the factory registered under name is invoked with config
// and, on success, the new instance is cached for subsequent calls.
//
// An error is returned when no factory is registered under name or when the
// factory fails. The factory's error is wrapped, so callers can inspect it
// with errors.Is and errors.As.
func (r *ProviderRegistry) GetProvider(name string, config util.Configuration) (KMSProvider, error) {
	// The write lock is held for the whole call, including the factory
	// invocation, so concurrent callers cannot construct the same provider
	// twice.
	r.mu.Lock()
	defer r.mu.Unlock()

	// Return existing instance if available.
	if instance, exists := r.instances[name]; exists {
		return instance, nil
	}

	// Find the factory.
	factory, exists := r.providers[name]
	if !exists {
		return nil, fmt.Errorf("KMS provider '%s' not registered", name)
	}

	// Create new instance.
	instance, err := factory(config)
	if err != nil {
		// %w (not %v) keeps the underlying error reachable through the chain.
		return nil, fmt.Errorf("failed to create KMS provider '%s': %w", name, err)
	}

	// Cache the instance.
	r.instances[name] = instance
	return instance, nil
}
// ListProviders returns the names of all providers registered in the
// package-level default registry.
func ListProviders() []string {
return defaultRegistry.ListProviders()
}
// ListProviders returns the names of every factory registered in this
// registry. The result is a fresh, non-nil slice; because it is filled by
// ranging over a map, the order of the names is not fixed.
func (r *ProviderRegistry) ListProviders() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	result := make([]string, len(r.providers))
	next := 0
	for providerName := range r.providers {
		result[next] = providerName
		next++
	}
	return result
}
// CloseAll closes all provider instances held by the package-level default
// registry and returns whatever error that registry's CloseAll reports.
func CloseAll() error {
return defaultRegistry.CloseAll()
}
// CloseAll closes every cached provider instance in this registry and then
// drops them all, regardless of whether closing succeeded. Each failure is
// wrapped with the provider's name; the failures are combined with
// errors.Join, so the result is nil when every Close succeeded.
func (r *ProviderRegistry) CloseAll() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	var closeErrs []error
	for providerName, provider := range r.instances {
		err := provider.Close()
		if err == nil {
			continue
		}
		closeErrs = append(closeErrs, fmt.Errorf("failed to close KMS provider '%s': %w", providerName, err))
	}

	// Start over with an empty cache so later GetProvider calls build new
	// instances instead of returning closed ones.
	r.instances = make(map[string]KMSProvider)

	return errors.Join(closeErrs...)
}
// KMSConfig represents the configuration for KMS: which provider to use and
// the settings handed to that provider's factory.
type KMSConfig struct {
Provider string `json:"provider"` // KMS provider name
Config map[string]interface{} `json:"config"` // Provider-specific configuration
}
// configAdapter adapts KMSConfig.Config to the util.Configuration interface.
//
// Lookups never fail: a missing key, or a value whose dynamic type the getter
// does not understand, yields that getter's zero value.
type configAdapter struct {
	config map[string]interface{}
}

// GetString returns the value stored under key when it is a string, and ""
// otherwise.
func (c *configAdapter) GetString(key string) string {
	if str, ok := c.config[key].(string); ok {
		return str
	}
	return ""
}

// GetBool returns the value stored under key when it is a bool, and false
// otherwise.
func (c *configAdapter) GetBool(key string) bool {
	if b, ok := c.config[key].(bool); ok {
		return b
	}
	return false
}

// GetInt returns the value stored under key converted to int, and 0 when the
// key is missing or the value is not numeric.
//
// Every built-in integer and floating-point type is accepted, because the
// dynamic type depends on how the map was produced: JSON decoding yields
// float64, while other decoders and hand-built maps commonly yield int64 or
// another sized integer. Floating-point values are truncated toward zero.
func (c *configAdapter) GetInt(key string) int {
	switch v := c.config[key].(type) {
	case int:
		return v
	case int8:
		return int(v)
	case int16:
		return int(v)
	case int32:
		return int(v)
	case int64:
		return int(v)
	case uint:
		return int(v)
	case uint8:
		return int(v)
	case uint16:
		return int(v)
	case uint32:
		return int(v)
	case uint64:
		return int(v)
	case float32:
		return int(v)
	case float64:
		return int(v)
	}
	return 0
}

// GetStringSlice returns the value stored under key as a string slice.
//
// A []string is returned as-is (not copied). A []interface{} is converted
// element by element into a slice of the same length; elements that are not
// strings are left as "". Any other value, or a missing key, yields nil.
func (c *configAdapter) GetStringSlice(key string) []string {
	switch v := c.config[key].(type) {
	case []string:
		return v
	case []interface{}:
		result := make([]string, len(v))
		for i, elem := range v {
			if str, ok := elem.(string); ok {
				result[i] = str
			}
		}
		return result
	}
	return nil
}

// SetDefault stores value under key unless the key is already present. The
// underlying map is allocated on demand, so SetDefault is safe on an adapter
// whose map is nil.
func (c *configAdapter) SetDefault(key string, value interface{}) {
	if c.config == nil {
		c.config = make(map[string]interface{})
	}
	if _, exists := c.config[key]; !exists {
		c.config[key] = value
	}
}
// globalKMSProvider holds the global KMS provider instance; within this file
// it is only written by InitializeGlobalKMS and SetGlobalKMSForTesting.
// globalKMSMutex guards every read and write of globalKMSProvider.
var (
globalKMSProvider KMSProvider
globalKMSMutex sync.RWMutex
)
// InitializeGlobalKMS initializes the global KMS provider from config.
//
// The provider is obtained through GetProvider using config.Provider as the
// name and config.Config (when non-nil) as its settings, and is then installed
// as the global provider. A previously installed provider is closed on a
// best-effort basis, unless it is the very instance being installed.
//
// It returns an error when config is nil, when config.Provider is empty, or
// when the provider cannot be obtained; in those cases the current global
// provider is left unchanged.
func InitializeGlobalKMS(config *KMSConfig) error {
	if config == nil || config.Provider == "" {
		return fmt.Errorf("KMS configuration is required")
	}

	// Leave providerConfig as a nil interface when no settings were supplied,
	// rather than wrapping a nil map in an adapter.
	var providerConfig util.Configuration
	if config.Config != nil {
		providerConfig = &configAdapter{config: config.Config}
	}

	provider, err := GetProvider(config.Provider, providerConfig)
	if err != nil {
		return err
	}

	globalKMSMutex.Lock()
	defer globalKMSMutex.Unlock()

	// GetProvider hands back its cached instance when called again with the
	// same name, so the "new" provider can be the one already installed.
	// Closing it in that case would leave a closed provider as the global.
	// NOTE(review): the comparison assumes provider implementations are
	// comparable values (e.g. pointers) — confirm.
	if previous := globalKMSProvider; previous != nil && previous != provider {
		// Best effort: a failure to close the old provider must not prevent
		// the new one from being installed.
		_ = previous.Close()
	}

	globalKMSProvider = provider
	return nil
}
// GetGlobalKMS returns the currently installed global KMS provider, or nil if
// none has been installed. The value is read under the read lock.
func GetGlobalKMS() KMSProvider {
	globalKMSMutex.RLock()
	current := globalKMSProvider
	globalKMSMutex.RUnlock()
	return current
}
// IsKMSEnabled returns true if KMS is enabled globally, i.e. if GetGlobalKMS
// returns a non-nil provider.
func IsKMSEnabled() bool {
return GetGlobalKMS() != nil
}
// WithKMSProvider is a helper function to execute code with a KMS provider.
//
// It obtains the provider for name via GetProvider (passing config) and calls
// fn with it, returning fn's result unchanged. If the provider cannot be
// obtained, fn is not called and the GetProvider error is returned.
func WithKMSProvider(name string, config util.Configuration, fn func(KMSProvider) error) error {
provider, err := GetProvider(name, config)
if err != nil {
return err
}
return fn(provider)
}
// TestKMSConnection tests the connection to a KMS provider by asking it to
// describe testKeyID.
//
// The test passes when DescribeKey succeeds, and also when it fails with a
// *KMSError whose Code is ErrCodeNotFoundException: a "not found" answer still
// proves the provider was reached. Any other failure is returned wrapped, so
// callers can inspect the cause with errors.Is and errors.As. A nil provider
// is reported as an error.
func TestKMSConnection(ctx context.Context, provider KMSProvider, testKeyID string) error {
	if provider == nil {
		return fmt.Errorf("KMS provider is nil")
	}

	// Try to describe a test key to verify connectivity.
	_, err := provider.DescribeKey(ctx, &DescribeKeyRequest{
		KeyID: testKeyID,
	})
	if err == nil {
		return nil
	}

	// errors.As (rather than a direct type assertion) also recognises a
	// *KMSError that a provider has wrapped with extra context.
	var kmsErr *KMSError
	if errors.As(err, &kmsErr) && kmsErr.Code == ErrCodeNotFoundException {
		// If the key doesn't exist, that's still a successful connection test.
		return nil
	}
	return fmt.Errorf("KMS connection test failed: %w", err)
}
// SetGlobalKMSForTesting sets the global KMS provider for testing purposes.
// This should only be used in tests.
//
// A previously installed provider is closed on a best-effort basis, unless it
// is the very instance being installed. Passing nil clears the global
// provider.
func SetGlobalKMSForTesting(provider KMSProvider) {
	globalKMSMutex.Lock()
	defer globalKMSMutex.Unlock()

	// Re-installing the current provider must not close it, otherwise the
	// global would end up pointing at a closed instance.
	// NOTE(review): the comparison assumes provider implementations are
	// comparable values (e.g. pointers) — confirm.
	if previous := globalKMSProvider; previous != nil && previous != provider {
		// Best effort: the Close error is deliberately ignored.
		_ = previous.Close()
	}

	globalKMSProvider = provider
}