|
|
@ -4,16 +4,17 @@ import ( |
|
|
|
"crypto/tls" |
|
|
|
"crypto/x509" |
|
|
|
"fmt" |
|
|
|
"google.golang.org/grpc/credentials/insecure" |
|
|
|
"google.golang.org/grpc/credentials/tls/certprovider/pemfile" |
|
|
|
"google.golang.org/grpc/security/advancedtls" |
|
|
|
"os" |
|
|
|
"slices" |
|
|
|
"strings" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/seaweedfs/seaweedfs/weed/glog" |
|
|
|
"github.com/seaweedfs/seaweedfs/weed/util" |
|
|
|
"google.golang.org/grpc" |
|
|
|
"google.golang.org/grpc/credentials/insecure" |
|
|
|
"google.golang.org/grpc/credentials/tls/certprovider/pemfile" |
|
|
|
"google.golang.org/grpc/security/advancedtls" |
|
|
|
) |
|
|
|
|
|
|
|
const CredRefreshingInterval = time.Duration(5) * time.Hour |
|
|
@ -62,7 +63,22 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption |
|
|
|
RootProvider: serverRootProvider, |
|
|
|
}, |
|
|
|
RequireClientCert: true, |
|
|
|
VerificationType: advancedtls.CertVerification, |
|
|
|
VerificationType: advancedtls.CertVerification, |
|
|
|
} |
|
|
|
options.MinTLSVersion, err = TlsVersionByName(config.GetString("tls.min_version")) |
|
|
|
if err != nil { |
|
|
|
glog.Warningf("tls min version parse failed, %v", err) |
|
|
|
return nil, nil |
|
|
|
} |
|
|
|
options.MaxTLSVersion, err = TlsVersionByName(config.GetString("tls.max_version")) |
|
|
|
if err != nil { |
|
|
|
glog.Warningf("tls max version parse failed, %v", err) |
|
|
|
return nil, nil |
|
|
|
} |
|
|
|
options.CipherSuites, err = TlsCipherSuiteByNames(config.GetString("tls.cipher_suites")) |
|
|
|
if err != nil { |
|
|
|
glog.Warningf("tls cipher suite parse failed, %v", err) |
|
|
|
return nil, nil |
|
|
|
} |
|
|
|
allowedCommonNames := config.GetString(component + ".allowed_commonNames") |
|
|
|
allowedWildcardDomain := config.GetString("grpc.allowed_wildcard_domain") |
|
|
@ -123,8 +139,8 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption { |
|
|
|
IdentityProvider: clientProvider, |
|
|
|
}, |
|
|
|
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { |
|
|
|
return &advancedtls.PostHandshakeVerificationResults{}, nil |
|
|
|
}, |
|
|
|
return &advancedtls.PostHandshakeVerificationResults{}, nil |
|
|
|
}, |
|
|
|
RootOptions: advancedtls.RootCertificateOptions{ |
|
|
|
RootProvider: clientRootProvider, |
|
|
|
}, |
|
|
@ -166,3 +182,57 @@ func (a Authenticator) Authenticate(params *advancedtls.HandshakeVerificationInf |
|
|
|
glog.Error(err) |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
func FixTlsConfig(viper *util.ViperProxy, config *tls.Config) error { |
|
|
|
var err error |
|
|
|
config.MinVersion, err = TlsVersionByName(viper.GetString("tls.min_version")) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
config.MaxVersion, err = TlsVersionByName(viper.GetString("tls.max_version")) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
config.CipherSuites, err = TlsCipherSuiteByNames(viper.GetString("tls.cipher_suites")) |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
func TlsVersionByName(name string) (uint16, error) { |
|
|
|
switch name { |
|
|
|
case "": |
|
|
|
return 0, nil |
|
|
|
case "SSLv3": |
|
|
|
return tls.VersionSSL30, nil |
|
|
|
case "TLS 1.0": |
|
|
|
return tls.VersionTLS10, nil |
|
|
|
case "TLS 1.1": |
|
|
|
return tls.VersionTLS11, nil |
|
|
|
case "TLS 1.2": |
|
|
|
return tls.VersionTLS12, nil |
|
|
|
case "TLS 1.3": |
|
|
|
return tls.VersionTLS13, nil |
|
|
|
default: |
|
|
|
return 0, fmt.Errorf("invalid tls version %s", name) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TlsCipherSuiteByNames(cipherSuiteNames string) ([]uint16, error) { |
|
|
|
cipherSuiteNames = strings.TrimSpace(cipherSuiteNames) |
|
|
|
if cipherSuiteNames == "" { |
|
|
|
return nil, nil |
|
|
|
} |
|
|
|
names := strings.Split(cipherSuiteNames, ",") |
|
|
|
cipherSuites := tls.CipherSuites() |
|
|
|
cipherIds := make([]uint16, 0, len(names)) |
|
|
|
for _, name := range names { |
|
|
|
name = strings.TrimSpace(name) |
|
|
|
index := slices.IndexFunc(cipherSuites, func(suite *tls.CipherSuite) bool { |
|
|
|
return name == suite.Name |
|
|
|
}) |
|
|
|
if index == -1 { |
|
|
|
return nil, fmt.Errorf("invalid tls cipher suite name %s", name) |
|
|
|
} |
|
|
|
cipherIds = append(cipherIds, cipherSuites[index].ID) |
|
|
|
} |
|
|
|
return cipherIds, nil |
|
|
|
} |