diff --git a/weed/pb/grpc_client_server.go b/weed/pb/grpc_client_server.go index cdac0ba99..edb60e4fa 100644 --- a/weed/pb/grpc_client_server.go +++ b/weed/pb/grpc_client_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/chrislusf/seaweedfs/weed/glog" + "math/rand" "net/http" "strconv" "strings" @@ -24,10 +25,15 @@ const ( var ( // cache grpc connections - grpcClients = make(map[string]*grpc.ClientConn) + grpcClients = make(map[string]*versionedGrpcClient) grpcClientsLock sync.Mutex ) +type versionedGrpcClient struct { + *grpc.ClientConn + version int +} + func init() { http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 1024 http.DefaultTransport.(*http.Transport).MaxIdleConns = 1024 @@ -79,7 +85,7 @@ func GrpcDial(ctx context.Context, address string, opts ...grpc.DialOption) (*gr return grpc.DialContext(ctx, address, options...) } -func getOrCreateConnection(address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { +func getOrCreateConnection(address string, opts ...grpc.DialOption) (*versionedGrpcClient, error) { grpcClientsLock.Lock() defer grpcClientsLock.Unlock() @@ -94,18 +100,34 @@ func getOrCreateConnection(address string, opts ...grpc.DialOption) (*grpc.Clien return nil, fmt.Errorf("fail to dial %s: %v", address, err) } - grpcClients[address] = grpcConnection + vgc := &versionedGrpcClient{ + grpcConnection, + rand.Int(), + } + grpcClients[address] = vgc - return grpcConnection, nil + return vgc, nil } func WithCachedGrpcClient(fn func(*grpc.ClientConn) error, address string, opts ...grpc.DialOption) error { - grpcConnection, err := getOrCreateConnection(address, opts...) + vgc, err := getOrCreateConnection(address, opts...) if err != nil { return fmt.Errorf("getOrCreateConnection %s: %v", address, err) } - return fn(grpcConnection) + executionErr := fn(vgc.ClientConn) + if executionErr != nil && strings.Contains(executionErr.Error(), "transport") { + grpcClientsLock.Lock() + if t, ok := grpcClients[address]; ok { + if t.version == vgc.version { + vgc.Close() + delete(grpcClients, address) + } + } + grpcClientsLock.Unlock() + } + + return executionErr } func ParseServerToGrpcAddress(server string) (serverGrpcAddress string, err error) {