Browse Source

fix: port in SNI address when using domainName instead of IP for master (#8500)

pull/8490/merge
Racci 17 hours ago
committed by GitHub
parent
commit
9e26d6f5dd
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 360
      go.sum
  2. 17
      weed/pb/server_address.go
  3. 58
      weed/pb/server_address_test.go
  4. 37
      weed/security/tls.go
  5. 219
      weed/security/tls_sni_test.go

360
go.sum
File diff suppressed because it is too large
View File

17
weed/pb/server_address.go

@ -86,6 +86,23 @@ func (sa ServerAddress) ToGrpcAddress() string {
return ServerToGrpcAddress(string(sa))
}
// ToHost returns the host part only, without any port information.
func (sa ServerAddress) ToHost() string {
httpAddr := sa.ToHttpAddress()
host, _, err := net.SplitHostPort(httpAddr)
if err == nil {
return host
}
// Fallback: if parsing fails, it's likely a host without a port.
// Handle bracketed IPv6 (e.g., "[::1]" without port) by trimming brackets.
if strings.HasPrefix(httpAddr, "[") && strings.HasSuffix(httpAddr, "]") {
return httpAddr[1 : len(httpAddr)-1]
}
return httpAddr
}
// LookUp may return an error for some records along with successful lookups - make sure you do not
// discard `addresses` even if `err == nil`
func (r ServerSrvAddress) LookUp() (addresses []ServerAddress, err error) {

58
weed/pb/server_address_test.go

@ -35,6 +35,64 @@ func TestServerAddresses_ToAddressMapOrSrv_shouldHandleIPPortList(t *testing.T)
}
}
func TestServerAddress_ToHost(t *testing.T) {
testCases := []struct {
name string
address ServerAddress
expected string
}{
{
name: "hostname with port",
address: ServerAddress("master.example.com:9333"),
expected: "master.example.com",
},
{
name: "IPv4 with port",
address: ServerAddress("192.168.1.1:9333"),
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
address: ServerAddress("[2001:db8::1]:9333"),
expected: "2001:db8::1",
},
{
name: "hostname without port",
address: ServerAddress("master.example.com"),
expected: "master.example.com",
},
{
name: "hostname with port.grpcPort",
address: ServerAddress("master.example.com:443.10443"),
expected: "master.example.com",
},
{
name: "IPv4 with port.grpcPort",
address: ServerAddress("192.168.1.1:8080.18080"),
expected: "192.168.1.1",
},
{
name: "IPv6 with port.grpcPort",
address: ServerAddress("[2001:db8::1]:8080.18080"),
expected: "2001:db8::1",
},
{
name: "bracketed IPv6 without port",
address: ServerAddress("[2001:db8::1]"),
expected: "2001:db8::1",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.address.ToHost()
if got != tc.expected {
t.Errorf("ServerAddress(%q).ToHost() = %q, want %q", tc.address, got, tc.expected)
}
})
}
}
func TestIPv6ServerAddressFormatting(t *testing.T) {
testCases := []struct {
name string

37
weed/security/tls.go

@ -1,9 +1,11 @@
package security
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"slices"
"strings"
@ -12,6 +14,7 @@ import (
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/util"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
"google.golang.org/grpc/security/advancedtls"
@ -24,6 +27,37 @@ type Authenticator struct {
AllowedCommonNames map[string]bool
}
// SNIStrippingTransportCredentials wraps another TransportCredentials
// and strips the port from the authority in ClientHandshake to prevent
// advancedtls from using the full "host:port" as ServerName in SNI.
type SNIStrippingTransportCredentials struct {
creds credentials.TransportCredentials
}
func (s *SNIStrippingTransportCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
host, _, err := net.SplitHostPort(authority)
if err == nil {
authority = host
}
return s.creds.ClientHandshake(ctx, authority, rawConn)
}
func (s *SNIStrippingTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return s.creds.ServerHandshake(rawConn)
}
func (s *SNIStrippingTransportCredentials) Info() credentials.ProtocolInfo {
return s.creds.Info()
}
func (s *SNIStrippingTransportCredentials) Clone() credentials.TransportCredentials {
return &SNIStrippingTransportCredentials{creds: s.creds.Clone()}
}
func (s *SNIStrippingTransportCredentials) OverrideServerName(serverNameOverride string) error {
return s.creds.OverrideServerName(serverNameOverride)
}
func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption, grpc.ServerOption) {
if config == nil {
return nil, nil
@ -151,7 +185,8 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption {
glog.Warningf("advancedtls.NewClientCreds(%v) failed: %v", options, err)
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
return grpc.WithTransportCredentials(ta)
wrapped := &SNIStrippingTransportCredentials{creds: ta}
return grpc.WithTransportCredentials(wrapped)
}
func LoadClientTLSHTTP(clientCertFile string) *tls.Config {

219
weed/security/tls_sni_test.go

@ -0,0 +1,219 @@
package security
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"testing"
"time"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/advancedtls"
)
func generateSelfSignedCert(t *testing.T) (tls.Certificate, *x509.CertPool) {
t.Helper()
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate CA key: %v", err)
}
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
if err != nil {
t.Fatalf("create CA cert: %v", err)
}
caCert, err := x509.ParseCertificate(caDER)
if err != nil {
t.Fatalf("parse CA cert: %v", err)
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate leaf key: %v", err)
}
leafTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
}
leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey)
if err != nil {
t.Fatalf("create leaf cert: %v", err)
}
leafCert := tls.Certificate{
Certificate: [][]byte{leafDER},
PrivateKey: leafKey,
}
pool := x509.NewCertPool()
pool.AddCert(caCert)
return leafCert, pool
}
func startTLSListenerCapturingSNI(t *testing.T, cert tls.Certificate) (string, <-chan string) {
t.Helper()
sniChan := make(chan string, 1)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
select {
case sniChan <- hello.ServerName:
close(sniChan)
default:
}
return nil, nil
},
}
ln, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
if err != nil {
t.Fatalf("tls.Listen: %v", err)
}
t.Cleanup(func() { ln.Close() })
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1)
conn.Read(buf)
}()
return ln.Addr().String(), sniChan
}
func newSNIStrippingCreds(t *testing.T, cert tls.Certificate, pool *x509.CertPool) credentials.TransportCredentials {
t.Helper()
clientCreds, err := advancedtls.NewClientCreds(&advancedtls.Options{
IdentityOptions: advancedtls.IdentityCertificateOptions{
Certificates: []tls.Certificate{cert},
},
RootOptions: advancedtls.RootCertificateOptions{
RootCertificates: pool,
},
VerificationType: advancedtls.CertVerification,
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
return &advancedtls.PostHandshakeVerificationResults{}, nil
},
})
if err != nil {
t.Fatalf("NewClientCreds: %v", err)
}
return &SNIStrippingTransportCredentials{creds: clientCreds}
}
func TestSNI_HostnameStripsPort(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
// gRPC passes "host:port" as authority; SNI wrapper strips the port
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "localhost:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "localhost" {
t.Errorf("SNI = %q, want %q", sni, "localhost")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func TestSNI_IPAddressEmptySNI(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
// RFC 6066: IP addresses MUST NOT be sent as SNI; Go's TLS sends empty ServerName for IPs
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "127.0.0.1:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "" {
t.Errorf("SNI = %q, want empty string", sni)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func TestSNI_IPv6AddressEmptySNI(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "[::1]:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "" {
t.Errorf("SNI = %q, want empty string", sni)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func portFromAddr(addr string) string {
_, port, _ := net.SplitHostPort(addr)
return port
}
Loading…
Cancel
Save