You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
274 lines
6.7 KiB
274 lines
6.7 KiB
package kms
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"github.com/seaweedfs/seaweedfs/weed/util"
)
|
|
|
|
// ProviderRegistry manages KMS provider implementations
|
|
type ProviderRegistry struct {
|
|
mu sync.RWMutex
|
|
providers map[string]ProviderFactory
|
|
instances map[string]KMSProvider
|
|
}
|
|
|
|
// ProviderFactory creates a new KMS provider instance
|
|
type ProviderFactory func(config util.Configuration) (KMSProvider, error)
|
|
|
|
var defaultRegistry = NewProviderRegistry()
|
|
|
|
// NewProviderRegistry creates a new provider registry
|
|
func NewProviderRegistry() *ProviderRegistry {
|
|
return &ProviderRegistry{
|
|
providers: make(map[string]ProviderFactory),
|
|
instances: make(map[string]KMSProvider),
|
|
}
|
|
}
|
|
|
|
// RegisterProvider registers a new KMS provider factory
|
|
func RegisterProvider(name string, factory ProviderFactory) {
|
|
defaultRegistry.RegisterProvider(name, factory)
|
|
}
|
|
|
|
// RegisterProvider registers a new KMS provider factory in this registry
|
|
func (r *ProviderRegistry) RegisterProvider(name string, factory ProviderFactory) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.providers[name] = factory
|
|
}
|
|
|
|
// GetProvider returns a KMS provider instance, creating it if necessary
|
|
func GetProvider(name string, config util.Configuration) (KMSProvider, error) {
|
|
return defaultRegistry.GetProvider(name, config)
|
|
}
|
|
|
|
// GetProvider returns a KMS provider instance, creating it if necessary
|
|
func (r *ProviderRegistry) GetProvider(name string, config util.Configuration) (KMSProvider, error) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
// Return existing instance if available
|
|
if instance, exists := r.instances[name]; exists {
|
|
return instance, nil
|
|
}
|
|
|
|
// Find the factory
|
|
factory, exists := r.providers[name]
|
|
if !exists {
|
|
return nil, fmt.Errorf("KMS provider '%s' not registered", name)
|
|
}
|
|
|
|
// Create new instance
|
|
instance, err := factory(config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create KMS provider '%s': %v", name, err)
|
|
}
|
|
|
|
// Cache the instance
|
|
r.instances[name] = instance
|
|
return instance, nil
|
|
}
|
|
|
|
// ListProviders returns the names of all registered providers
|
|
func ListProviders() []string {
|
|
return defaultRegistry.ListProviders()
|
|
}
|
|
|
|
// ListProviders returns the names of all registered providers
|
|
func (r *ProviderRegistry) ListProviders() []string {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
names := make([]string, 0, len(r.providers))
|
|
for name := range r.providers {
|
|
names = append(names, name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// CloseAll closes all provider instances
|
|
func CloseAll() error {
|
|
return defaultRegistry.CloseAll()
|
|
}
|
|
|
|
// CloseAll closes all provider instances in this registry
|
|
func (r *ProviderRegistry) CloseAll() error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
var allErrors []error
|
|
for name, instance := range r.instances {
|
|
if err := instance.Close(); err != nil {
|
|
allErrors = append(allErrors, fmt.Errorf("failed to close KMS provider '%s': %w", name, err))
|
|
}
|
|
}
|
|
|
|
// Clear the instances map
|
|
r.instances = make(map[string]KMSProvider)
|
|
|
|
return errors.Join(allErrors...)
|
|
}
|
|
|
|
// KMSConfig selects a KMS provider and carries its settings. It is the input
// to InitializeGlobalKMS.
type KMSConfig struct {
	// Provider is the name the provider's factory was registered under.
	Provider string `json:"provider"`

	// Config holds provider-specific settings. When non-nil, they are handed
	// to the provider factory through the util.Configuration interface.
	Config map[string]any `json:"config"`
}
|
|
|
|
// configAdapter exposes a KMSConfig.Config map through the util.Configuration
// method set (GetString, GetBool, GetInt, GetStringSlice, SetDefault), so the
// map can be handed to a ProviderFactory.
//
// Getters never fail. A missing key, or a value of an unsupported type,
// yields the zero value. Reading through a nil map is safe; SetDefault
// allocates the map on first use.
type configAdapter struct {
	config map[string]interface{}
}

// GetString returns the value stored under key if it is a string, otherwise "".
func (c *configAdapter) GetString(key string) string {
	if s, ok := c.config[key].(string); ok {
		return s
	}
	return ""
}

// GetBool returns the value stored under key if it is a bool, otherwise false.
func (c *configAdapter) GetBool(key string) bool {
	if b, ok := c.config[key].(bool); ok {
		return b
	}
	return false
}

// GetInt returns the value stored under key converted to int when it holds
// any built-in integer or floating-point type, otherwise 0.
//
// encoding/json decodes numbers into interface{} as float64. Other widths
// (int64, uint32, float32, ...) are accepted too, because other decoders or
// programmatic callers may produce them.
//
// Floats are truncated toward zero. Values outside int's range follow Go's
// conversion rules: integers wrap, and floats give implementation-dependent
// results.
func (c *configAdapter) GetInt(key string) int {
	switch v := c.config[key].(type) {
	case int:
		return v
	case int8:
		return int(v)
	case int16:
		return int(v)
	case int32:
		return int(v)
	case int64:
		return int(v)
	case uint:
		return int(v)
	case uint8:
		return int(v)
	case uint16:
		return int(v)
	case uint32:
		return int(v)
	case uint64:
		return int(v)
	case float32:
		return int(v)
	case float64:
		return int(v)
	default:
		return 0
	}
}

// GetStringSlice returns the value stored under key as a []string:
//   - a []string value is returned as-is (not copied);
//   - a []interface{} value (the shape encoding/json produces) is converted
//     element by element, and non-string elements become "";
//   - any other value, or a missing key, yields nil.
func (c *configAdapter) GetStringSlice(key string) []string {
	switch v := c.config[key].(type) {
	case []string:
		return v
	case []interface{}:
		out := make([]string, len(v))
		for i, item := range v {
			if s, ok := item.(string); ok {
				out[i] = s
			}
		}
		return out
	}
	return nil
}

// SetDefault stores value under key only if key is not already present.
// Note that this writes into the wrapped map itself.
func (c *configAdapter) SetDefault(key string, value interface{}) {
	if c.config == nil {
		c.config = make(map[string]interface{})
	}
	if _, exists := c.config[key]; !exists {
		c.config[key] = value
	}
}
|
|
|
|
// GlobalKMSProvider holds the global KMS provider instance
|
|
var (
|
|
globalKMSProvider KMSProvider
|
|
globalKMSMutex sync.RWMutex
|
|
)
|
|
|
|
// InitializeGlobalKMS initializes the global KMS provider
|
|
func InitializeGlobalKMS(config *KMSConfig) error {
|
|
if config == nil || config.Provider == "" {
|
|
return fmt.Errorf("KMS configuration is required")
|
|
}
|
|
|
|
// Adapt the config to util.Configuration interface
|
|
var providerConfig util.Configuration
|
|
if config.Config != nil {
|
|
providerConfig = &configAdapter{config: config.Config}
|
|
}
|
|
|
|
provider, err := GetProvider(config.Provider, providerConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
globalKMSMutex.Lock()
|
|
defer globalKMSMutex.Unlock()
|
|
|
|
// Close existing provider if any
|
|
if globalKMSProvider != nil {
|
|
globalKMSProvider.Close()
|
|
}
|
|
|
|
globalKMSProvider = provider
|
|
return nil
|
|
}
|
|
|
|
// GetGlobalKMS returns the global KMS provider
|
|
func GetGlobalKMS() KMSProvider {
|
|
globalKMSMutex.RLock()
|
|
defer globalKMSMutex.RUnlock()
|
|
return globalKMSProvider
|
|
}
|
|
|
|
// IsKMSEnabled returns true if KMS is enabled globally
|
|
func IsKMSEnabled() bool {
|
|
return GetGlobalKMS() != nil
|
|
}
|
|
|
|
// WithKMSProvider is a helper function to execute code with a KMS provider
|
|
func WithKMSProvider(name string, config util.Configuration, fn func(KMSProvider) error) error {
|
|
provider, err := GetProvider(name, config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return fn(provider)
|
|
}
|
|
|
|
// TestKMSConnection tests the connection to a KMS provider
|
|
func TestKMSConnection(ctx context.Context, provider KMSProvider, testKeyID string) error {
|
|
if provider == nil {
|
|
return fmt.Errorf("KMS provider is nil")
|
|
}
|
|
|
|
// Try to describe a test key to verify connectivity
|
|
_, err := provider.DescribeKey(ctx, &DescribeKeyRequest{
|
|
KeyID: testKeyID,
|
|
})
|
|
|
|
if err != nil {
|
|
// If the key doesn't exist, that's still a successful connection test
|
|
if kmsErr, ok := err.(*KMSError); ok && kmsErr.Code == ErrCodeNotFoundException {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("KMS connection test failed: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetGlobalKMSForTesting sets the global KMS provider for testing purposes
|
|
// This should only be used in tests
|
|
func SetGlobalKMSForTesting(provider KMSProvider) {
|
|
globalKMSMutex.Lock()
|
|
defer globalKMSMutex.Unlock()
|
|
|
|
// Close existing provider if any
|
|
if globalKMSProvider != nil {
|
|
globalKMSProvider.Close()
|
|
}
|
|
|
|
globalKMSProvider = provider
|
|
}
|