You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

219 lines
5.9 KiB

package security
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"testing"
"time"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/advancedtls"
)
func generateSelfSignedCert(t *testing.T) (tls.Certificate, *x509.CertPool) {
t.Helper()
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate CA key: %v", err)
}
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
if err != nil {
t.Fatalf("create CA cert: %v", err)
}
caCert, err := x509.ParseCertificate(caDER)
if err != nil {
t.Fatalf("parse CA cert: %v", err)
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate leaf key: %v", err)
}
leafTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
}
leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey)
if err != nil {
t.Fatalf("create leaf cert: %v", err)
}
leafCert := tls.Certificate{
Certificate: [][]byte{leafDER},
PrivateKey: leafKey,
}
pool := x509.NewCertPool()
pool.AddCert(caCert)
return leafCert, pool
}
func startTLSListenerCapturingSNI(t *testing.T, cert tls.Certificate) (string, <-chan string) {
t.Helper()
sniChan := make(chan string, 1)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
select {
case sniChan <- hello.ServerName:
close(sniChan)
default:
}
return nil, nil
},
}
ln, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
if err != nil {
t.Fatalf("tls.Listen: %v", err)
}
t.Cleanup(func() { ln.Close() })
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1)
conn.Read(buf)
}()
return ln.Addr().String(), sniChan
}
func newSNIStrippingCreds(t *testing.T, cert tls.Certificate, pool *x509.CertPool) credentials.TransportCredentials {
t.Helper()
clientCreds, err := advancedtls.NewClientCreds(&advancedtls.Options{
IdentityOptions: advancedtls.IdentityCertificateOptions{
Certificates: []tls.Certificate{cert},
},
RootOptions: advancedtls.RootCertificateOptions{
RootCertificates: pool,
},
VerificationType: advancedtls.CertVerification,
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
return &advancedtls.PostHandshakeVerificationResults{}, nil
},
})
if err != nil {
t.Fatalf("NewClientCreds: %v", err)
}
return &SNIStrippingTransportCredentials{creds: clientCreds}
}
func TestSNI_HostnameStripsPort(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
// gRPC passes "host:port" as authority; SNI wrapper strips the port
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "localhost:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "localhost" {
t.Errorf("SNI = %q, want %q", sni, "localhost")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func TestSNI_IPAddressEmptySNI(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
// RFC 6066: IP addresses MUST NOT be sent as SNI; Go's TLS sends empty ServerName for IPs
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "127.0.0.1:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "" {
t.Errorf("SNI = %q, want empty string", sni)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func TestSNI_IPv6AddressEmptySNI(t *testing.T) {
cert, pool := generateSelfSignedCert(t)
wrapped := newSNIStrippingCreds(t, cert, pool)
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err = wrapped.ClientHandshake(ctx, "[::1]:"+portFromAddr(addr), conn)
if err != nil {
t.Fatalf("ClientHandshake: %v", err)
}
select {
case sni := <-sniChan:
if sni != "" {
t.Errorf("SNI = %q, want empty string", sni)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for SNI")
}
}
func portFromAddr(addr string) string {
_, port, _ := net.SplitHostPort(addr)
return port
}