diff --git a/weed/pb/grpc_client_server.go b/weed/pb/grpc_client_server.go index 9c7cf124b..ce706e282 100644 --- a/weed/pb/grpc_client_server.go +++ b/weed/pb/grpc_client_server.go @@ -76,43 +76,33 @@ func GrpcDial(ctx context.Context, address string, opts ...grpc.DialOption) (*gr return grpc.DialContext(ctx, address, options...) } -func WithCachedGrpcClient(fn func(*grpc.ClientConn) error, address string, opts ...grpc.DialOption) error { +func getOrCreateConnection(address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { grpcClientsLock.Lock() + defer grpcClientsLock.Unlock() existingConnection, found := grpcClients[address] if found { - grpcClientsLock.Unlock() - err := fn(existingConnection) - if err != nil { - grpcClientsLock.Lock() - // delete(grpcClients, address) - grpcClientsLock.Unlock() - // println("closing existing connection to", existingConnection.Target()) - // existingConnection.Close() - } - return err + return existingConnection, nil } grpcConnection, err := GrpcDial(context.Background(), address, opts...) if err != nil { - grpcClientsLock.Unlock() - return fmt.Errorf("fail to dial %s: %v", address, err) + return nil, fmt.Errorf("fail to dial %s: %v", address, err) } grpcClients[address] = grpcConnection - grpcClientsLock.Unlock() - err = fn(grpcConnection) + return grpcConnection, nil +} + +func WithCachedGrpcClient(fn func(*grpc.ClientConn) error, address string, opts ...grpc.DialOption) error { + + grpcConnection, err := getOrCreateConnection(address, opts...) if err != nil { - grpcClientsLock.Lock() - // delete(grpcClients, address) - grpcClientsLock.Unlock() - // println("closing created new connection to", grpcConnection.Target()) - // grpcConnection.Close() + return fmt.Errorf("getOrCreateConnection %s: %v", address, err) } - - return err + return fn(grpcConnection) } func ParseServerToGrpcAddress(server string) (serverGrpcAddress string, err error) {