diff --git a/go.mod b/go.mod index c8a914ab0..e16b67e6e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/buraksezer/consistent v0.0.0-20191006190839-693edf70fd72 github.com/cespare/xxhash v1.1.0 github.com/chrislusf/raft v1.0.4 + github.com/cloudflare/golibs v0.0.0-20201113145655-eb7a42c5e0be // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/disintegration/imaging v1.6.2 @@ -64,6 +65,7 @@ require ( github.com/seaweedfs/fuse v1.1.3 github.com/seaweedfs/goexif v1.0.2 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e + github.com/spance/suft v0.0.0-20161129124228-358fdb24d82d // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/viper v1.4.0 diff --git a/go.sum b/go.sum index ab1ef4739..a67a1bd39 100644 --- a/go.sum +++ b/go.sum @@ -158,6 +158,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/golibs v0.0.0-20201113145655-eb7a42c5e0be h1:E+nLopD71RbvKLi2IGciQu/gtRmofNOJIBplVOaeKEw= +github.com/cloudflare/golibs v0.0.0-20201113145655-eb7a42c5e0be/go.mod h1:HlgKKR8V5a1wroIDDIz3/A+T+9Janfq+7n1P5sEFdi0= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa h1:OaNxuTZr7kxeODyLWsRMC+OD03aFUH+mW6r2d+MWa5Y= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= @@ -710,6 +712,8 @@ github.com/smartystreets/gunit v1.3.4/go.mod h1:ZjM1ozSIMJlAz/ay4SG8PeKF00ckUp+z github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/spance/suft v0.0.0-20161129124228-358fdb24d82d h1:izNZHQvaeiCisdxVrhrbQgigPIJt0RJPrpRNWMWdjaQ= +github.com/spance/suft v0.0.0-20161129124228-358fdb24d82d/go.mod h1:S8IZv4Sq+amXiFLdengs5dVl9IdMjR/1yq35vjnYd54= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= diff --git a/weed/command/volume.go b/weed/command/volume.go index 0f3dba361..1468aa2ee 100644 --- a/weed/command/volume.go +++ b/weed/command/volume.go @@ -3,6 +3,8 @@ package command import ( "fmt" "github.com/chrislusf/seaweedfs/weed/storage/types" + "github.com/chrislusf/seaweedfs/weed/udptransfer" + "net" "net/http" httppprof "net/http/pprof" "os" @@ -29,7 +31,6 @@ import ( stats_collect "github.com/chrislusf/seaweedfs/weed/stats" "github.com/chrislusf/seaweedfs/weed/storage" "github.com/chrislusf/seaweedfs/weed/util" - "pack.ag/tftp" ) var ( @@ -399,15 +400,37 @@ func (v VolumeServerOptions) startTcpService(volumeServer *weed_server.VolumeSer func (v VolumeServerOptions) startUdpService(volumeServer *weed_server.VolumeServer) { listeningAddress := *v.bindIp + ":" + strconv.Itoa(*v.port+20001) - tftpServer, err := tftp.NewServer(listeningAddress) + + listener, err := udptransfer.NewEndpoint(&udptransfer.Params{ + LocalAddr: listeningAddress, + Bandwidth: 100, + FastRetransmit: true, + FlatTraffic: true, + IsServ: true, + }) if err != nil { glog.Fatalf("Volume server listen on %s:%v", listeningAddress, err) } - tftpServer.WriteHandler(volumeServer) - tftpServer.ReadHandler(volumeServer) + defer listener.Close() - glog.V(0).Infoln("Start Seaweed volume server", util.Version(), "UDP at", listeningAddress) - if e:= tftpServer.ListenAndServe(); e != nil { - glog.Fatalf("Volume server UDP on %s:%v", listeningAddress, e) + for { + conn, err := listener.Accept() + if err == nil { + glog.V(0).Infof("Client from %s", conn.RemoteAddr()) + go volumeServer.HandleTcpConnection(conn) + } else if isTemporaryError(err) { + continue + } else { + break + } } } + +func isTemporaryError(err error) bool { + if err != nil { + if ne, y := err.(net.Error); y { + return ne.Temporary() + } + } + return false +} \ No newline at end of file diff --git a/weed/udptransfer/conn.go b/weed/udptransfer/conn.go new file mode 100644 index 000000000..e2eca49da --- /dev/null +++ b/weed/udptransfer/conn.go @@ -0,0 +1,715 @@ +package udptransfer + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "time" +) + +const ( + _MAX_RETRIES = 6 + _MIN_RTT = 8 + _MIN_RTO = 30 + _MIN_ATO = 2 + _MAX_ATO = 10 + _MIN_SWND = 10 + _MAX_SWND = 960 +) + +const ( + _VACK_SCHED = iota + 1 + _VACK_QUICK + _VACK_MUST + _VSWND_ACTIVE + _VRETR_IMMED +) + +const ( + _RETR_REST = -1 + _CLOSE = 0xff +) + +var debug int + +func nodeOf(pk *packet) *qNode { + return &qNode{packet: pk} +} + +func (c *Conn) internalRecvLoop() { + defer func() { + // avoid send to closed channel while some replaying + // data packets were received in shutting down. + _ = recover() + }() + var buf, body []byte + for { + select { + case buf = <-c.evRecv: + if buf != nil { + body = buf[_TH_SIZE:] + } else { // shutdown + return + } + } + pk := new(packet) + // keep the original buffer, so we could recycle it in future + pk.buffer = buf + unmarshall(pk, body) + if pk.flag&_F_SACK != 0 { + c.processSAck(pk) + continue + } + if pk.flag&_F_ACK != 0 { + c.processAck(pk) + } + if pk.flag&_F_DATA != 0 { + c.insertData(pk) + } else if pk.flag&_F_FIN != 0 { + if pk.flag&_F_RESET != 0 { + go c.forceShutdownWithLock() + } else { + go c.closeR(pk) + } + } + } +} + +func (c *Conn) internalSendLoop() { + var timer = time.NewTimer(time.Duration(c.rtt) * time.Millisecond) + for { + select { + case v := <-c.evSWnd: + switch v { + case _VRETR_IMMED: + c.outlock.Lock() + c.retransmit2() + c.outlock.Unlock() + case _VSWND_ACTIVE: + timer.Reset(time.Duration(c.rtt) * time.Millisecond) + case _CLOSE: + return + } + case <-timer.C: // timeout yet + var notifySender bool + c.outlock.Lock() + rest, _ := c.retransmit() + switch rest { + case _RETR_REST, 0: // nothing to send + if c.outQ.size() > 0 { + timer.Reset(time.Duration(c.rtt) * time.Millisecond) + } else { + timer.Stop() + // avoid sender blocking + notifySender = true + } + default: // recent rto point + timer.Reset(time.Duration(minI64(rest, c.rtt)) * time.Millisecond) + } + c.outlock.Unlock() + if notifySender { + select { + case c.evSend <- 1: + default: + } + } + } + } +} + +func (c *Conn) internalAckLoop() { + // var ackTimer = time.NewTicker(time.Duration(c.ato)) + var ackTimer = time.NewTimer(time.Duration(c.ato) * time.Millisecond) + var lastAckState byte + for { + var v byte + select { + case <-ackTimer.C: + // may cause sending duplicated ack if ato>rtt + v = _VACK_QUICK + case v = <-c.evAck: + ackTimer.Reset(time.Duration(c.ato) * time.Millisecond) + state := lastAckState + lastAckState = v + if state != v { + if v == _CLOSE { + return + } + v = _VACK_MUST + } + } + c.inlock.Lock() + if pkAck := c.makeAck(v); pkAck != nil { + c.internalWrite(nodeOf(pkAck)) + } + c.inlock.Unlock() + } +} + +func (c *Conn) retransmit() (rest int64, count int32) { + var now, rto = Now(), c.rto + var limit = c.cwnd + for item := c.outQ.head; item != nil && limit > 0; item = item.next { + if item.scnt != _SENT_OK { // ACKed has scnt==-1 + diff := now - item.sent + if diff > rto { // already rto + c.internalWrite(item) + count++ + } else { + // continue search next min rto duration + if rest > 0 { + rest = minI64(rest, rto-diff+1) + } else { + rest = rto - diff + 1 + } + limit-- + } + } + } + c.outDupCnt += int(count) + if count > 0 { + shrcond := (c.fastRetransmit && count > maxI32(c.cwnd>>5, 4)) || (!c.fastRetransmit && count > c.cwnd>>3) + if shrcond && now-c.lastShrink > c.rto { + log.Printf("shrink cwnd from=%d to=%d s/4=%d", c.cwnd, c.cwnd>>1, c.swnd>>2) + c.lastShrink = now + // shrink cwnd and ensure cwnd >= swnd/4 + if c.cwnd > c.swnd>>1 { + c.cwnd >>= 1 + } + } + } + if c.outQ.size() > 0 { + return + } + return _RETR_REST, 0 +} + +func (c *Conn) retransmit2() (count int32) { + var limit, now = minI32(c.outPending>>4, 8), Now() + var fRtt = c.rtt + if now-c.lastShrink > c.rto { + fRtt += maxI64(c.rtt>>4, 1) + } else { + fRtt += maxI64(c.rtt>>1, 2) + } + for item := c.outQ.head; item != nil && count < limit; item = item.next { + if item.scnt != _SENT_OK { // ACKed has scnt==-1 + if item.miss >= 3 && now-item.sent >= fRtt { + item.miss = 0 + c.internalWrite(item) + count++ + } + } + } + c.fRCnt += int(count) + c.outDupCnt += int(count) + return +} + +func (c *Conn) inputAndSend(pk *packet) error { + item := &qNode{packet: pk} + if c.mySeq&3 == 1 { + c.tSlotT0 = NowNS() + } + c.outlock.Lock() + // inflight packets exceeds cwnd + // inflight includes: 1, unacked; 2, missed + for c.outPending >= c.cwnd+c.missed { + c.outlock.Unlock() + if c.wtmo > 0 { + var tmo int64 + tmo, c.wtmo = c.wtmo, 0 + select { + case v := <-c.evSend: + if v == _CLOSE { + return io.EOF + } + case <-NewTimerChan(tmo): + return ErrIOTimeout + } + } else { + if v := <-c.evSend; v == _CLOSE { + return io.EOF + } + } + c.outlock.Lock() + } + c.outPending++ + c.outPkCnt++ + c.mySeq++ + pk.seq = c.mySeq + c.outQ.appendTail(item) + c.internalWrite(item) + c.outlock.Unlock() + // active resending timer, must blocking + c.evSWnd <- _VSWND_ACTIVE + if c.mySeq&3 == 0 && c.flatTraffic { + // calculate time error bewteen tslot with actual usage. + // consider last sleep time error + t1 := NowNS() + terr := c.tSlot<<2 - c.lastSErr - (t1 - c.tSlotT0) + // rest terr/2 if current time usage less than tslot of 100us. + if terr > 1e5 { // 100us + time.Sleep(time.Duration(terr >> 1)) + c.lastSErr = maxI64(NowNS()-t1-terr, 0) + } else { + c.lastSErr >>= 1 + } + } + return nil +} + +func (c *Conn) internalWrite(item *qNode) { + if item.scnt >= 20 { + // no exception of sending fin + if item.flag&_F_FIN != 0 { + c.fakeShutdown() + c.dest = nil + return + } else { + log.Println("Warn: too many retries", item) + if c.urgent > 0 { // abort + c.forceShutdown() + return + } else { // continue to retry 10 + c.urgent++ + item.scnt = 10 + } + } + } + // update current sent time and prev sent time + item.sent, item.sent_1 = Now(), item.sent + item.scnt++ + buf := item.marshall(c.connID) + if debug >= 3 { + var pkType = packetTypeNames[item.flag] + if item.flag&_F_SACK != 0 { + log.Printf("send %s trp=%d on=%d %x", pkType, item.seq, item.ack, buf[_AH_SIZE+4:]) + } else { + log.Printf("send %s seq=%d ack=%d scnt=%d len=%d", pkType, item.seq, item.ack, item.scnt, len(buf)-_TH_SIZE) + } + } + c.sock.WriteToUDP(buf, c.dest) +} + +func (c *Conn) logAck(ack uint32) { + c.lastAck = ack + c.lastAckTime = Now() +} + +func (c *Conn) makeLastAck() (pk *packet) { + c.inlock.Lock() + defer c.inlock.Unlock() + if Now()-c.lastAckTime < c.rtt { + return nil + } + pk = &packet{ + ack: maxU32(c.lastAck, c.inQ.maxCtnSeq), + flag: _F_ACK, + } + c.logAck(pk.ack) + return +} + +func (c *Conn) makeAck(level byte) (pk *packet) { + now := Now() + if level < _VACK_MUST && now-c.lastAckTime < c.ato { + if level < _VACK_QUICK || now-c.lastAckTime < minI64(c.ato>>2, 1) { + return + } + } + // ready Q <-| + // |-> outQ start (or more right) + // |-> bitmap start + // [predecessor] [predecessor+1] [predecessor+2] ..... + var fakeSAck bool + var predecessor = c.inQ.maxCtnSeq + bmap, tbl := c.inQ.makeHolesBitmap(predecessor) + if len(bmap) <= 0 { // fake sack + bmap = make([]uint64, 1) + bmap[0], tbl = 1, 1 + fakeSAck = true + } + // head 4-byte: TBL:1 | SCNT:1 | DELAY:2 + buf := make([]byte, len(bmap)*8+4) + pk = &packet{ + ack: predecessor + 1, + flag: _F_SACK, + payload: buf, + } + if fakeSAck { + pk.ack-- + } + buf[0] = byte(tbl) + // mark delayed time according to the time reference point + if trp := c.inQ.lastIns; trp != nil { + delayed := now - trp.sent + if delayed < c.rtt { + pk.seq = trp.seq + pk.flag |= _F_TIME + buf[1] = trp.scnt + if delayed <= 0 { + delayed = 1 + } + binary.BigEndian.PutUint16(buf[2:], uint16(delayed)) + } + } + buf1 := buf[4:] + for i, b := range bmap { + binary.BigEndian.PutUint64(buf1[i*8:], b) + } + c.logAck(predecessor) + return +} + +func unmarshallSAck(data []byte) (bmap []uint64, tbl uint32, delayed uint16, scnt uint8) { + if len(data) > 0 { + bmap = make([]uint64, len(data)>>3) + } else { + return + } + tbl = uint32(data[0]) + scnt = data[1] + delayed = binary.BigEndian.Uint16(data[2:]) + data = data[4:] + for i := 0; i < len(bmap); i++ { + bmap[i] = binary.BigEndian.Uint64(data[i*8:]) + } + return +} + +func calSwnd(bandwidth, rtt int64) int32 { + w := int32(bandwidth * rtt / (8000 * _MSS)) + if w <= _MAX_SWND { + if w >= _MIN_SWND { + return w + } else { + return _MIN_SWND + } + } else { + return _MAX_SWND + } +} + +func (c *Conn) measure(seq uint32, delayed int64, scnt uint8) { + target := c.outQ.get(seq) + if target != nil { + var lastSent int64 + switch target.scnt - scnt { + case 0: + // not sent again since this ack was sent out + lastSent = target.sent + case 1: + // sent again once since this ack was sent out + // then use prev sent time + lastSent = target.sent_1 + default: + // can't measure here because the packet was sent too many times + return + } + // real-time rtt + rtt := Now() - lastSent - delayed + // reject these abnormal measures: + // 1. rtt too small -> rtt/8 + // 2. backlogging too long + if rtt < maxI64(c.rtt>>3, 1) || delayed > c.rtt>>1 { + return + } + // srtt: update 1/8 + err := rtt - (c.srtt >> 3) + c.srtt += err + c.rtt = c.srtt >> 3 + if c.rtt < _MIN_RTT { + c.rtt = _MIN_RTT + } + // s-swnd: update 1/4 + swnd := c.swnd<<3 - c.swnd + calSwnd(c.bandwidth, c.rtt) + c.swnd = swnd >> 3 + c.tSlot = c.rtt * 1e6 / int64(c.swnd) + c.ato = c.rtt >> 4 + if c.ato < _MIN_ATO { + c.ato = _MIN_ATO + } else if c.ato > _MAX_ATO { + c.ato = _MAX_ATO + } + if err < 0 { + err = -err + err -= c.mdev >> 2 + if err > 0 { + err >>= 3 + } + } else { + err -= c.mdev >> 2 + } + // mdev: update 1/4 + c.mdev += err + rto := c.rtt + maxI64(c.rtt<<1, c.mdev) + if rto >= c.rto { + c.rto = rto + } else { + c.rto = (c.rto + rto) >> 1 + } + if c.rto < _MIN_RTO { + c.rto = _MIN_RTO + } + if debug >= 1 { + log.Printf("--- rtt=%d srtt=%d rto=%d swnd=%d", c.rtt, c.srtt, c.rto, c.swnd) + } + } +} + +func (c *Conn) processSAck(pk *packet) { + c.outlock.Lock() + bmap, tbl, delayed, scnt := unmarshallSAck(pk.payload) + if bmap == nil { // bad packet + c.outlock.Unlock() + return + } + if pk.flag&_F_TIME != 0 { + c.measure(pk.seq, int64(delayed), scnt) + } + deleted, missed, continuous := c.outQ.deleteByBitmap(bmap, pk.ack, tbl) + if deleted > 0 { + c.ackHit(deleted, missed) + // lock is released + } else { + c.outlock.Unlock() + } + if c.fastRetransmit && !continuous { + // peer Q is uncontinuous, then trigger FR + if deleted == 0 { + c.evSWnd <- _VRETR_IMMED + } else { + select { + case c.evSWnd <- _VRETR_IMMED: + default: + } + } + } + if debug >= 2 { + log.Printf("SACK qhead=%d deleted=%d outPending=%d on=%d %016x", + c.outQ.distanceOfHead(0), deleted, c.outPending, pk.ack, bmap) + } +} + +func (c *Conn) processAck(pk *packet) { + c.outlock.Lock() + if end := c.outQ.get(pk.ack); end != nil { // ack hit + _, deleted := c.outQ.deleteBefore(end) + c.ackHit(deleted, 0) // lock is released + if debug >= 2 { + log.Printf("ACK hit on=%d", pk.ack) + } + // special case: ack the FIN + if pk.seq == _FIN_ACK_SEQ { + select { + case c.evClose <- _S_FIN0: + default: + } + } + } else { // duplicated ack + if debug >= 2 { + log.Printf("ACK miss on=%d", pk.ack) + } + if pk.flag&_F_SYN != 0 { // No.3 Ack lost + if pkAck := c.makeLastAck(); pkAck != nil { + c.internalWrite(nodeOf(pkAck)) + } + } + c.outlock.Unlock() + } +} + +func (c *Conn) ackHit(deleted, missed int32) { + // must in outlock + c.outPending -= deleted + now := Now() + if c.cwnd < c.swnd && now-c.lastShrink > c.rto { + if c.cwnd < c.swnd>>1 { + c.cwnd <<= 1 + } else { + c.cwnd += deleted << 1 + } + } + if c.cwnd > c.swnd { + c.cwnd = c.swnd + } + if now-c.lastRstMis > c.ato { + c.lastRstMis = now + c.missed = missed + } else { + c.missed = c.missed>>1 + missed + } + if qswnd := c.swnd >> 4; c.missed > qswnd { + c.missed = qswnd + } + c.outlock.Unlock() + select { + case c.evSend <- 1: + default: + } +} + +func (c *Conn) insertData(pk *packet) { + c.inlock.Lock() + defer c.inlock.Unlock() + exists := c.inQ.contains(pk.seq) + // duplicated with already queued or history + // means: last ACK were lost + if exists || pk.seq <= c.inQ.maxCtnSeq { + // then send ACK for dups + select { + case c.evAck <- _VACK_MUST: + default: + } + if debug >= 2 { + dumpQ(fmt.Sprint("duplicated ", pk.seq), c.inQ) + } + c.inDupCnt++ + return + } + // record current time in sent and regard as received time + item := &qNode{packet: pk, sent: Now()} + dis := c.inQ.searchInsert(item, c.lastReadSeq) + if debug >= 3 { + log.Printf("\t\t\trecv DATA seq=%d dis=%d maxCtn=%d lastReadSeq=%d", item.seq, dis, c.inQ.maxCtnSeq, c.lastReadSeq) + } + + var ackState byte = _VACK_MUST + var available bool + switch dis { + case 0: // impossible + return + case 1: + if c.inQDirty { + available = c.inQ.updateContinuous(item) + if c.inQ.isWholeContinuous() { // whole Q is ordered + c.inQDirty = false + } else { //those holes still exists. + ackState = _VACK_QUICK + } + } else { + // here is an ideal situation + c.inQ.maxCtnSeq = pk.seq + available = true + ackState = _VACK_SCHED + } + + default: // there is an unordered packet, hole occurred here. + if !c.inQDirty { + c.inQDirty = true + } + } + + // write valid received count + c.inPkCnt++ + c.inQ.lastIns = item + // try notify ack + select { + case c.evAck <- ackState: + default: + } + if available { // try notify reader + select { + case c.evRead <- 1: + default: + } + } +} + +func (c *Conn) readInQ() bool { + c.inlock.Lock() + defer c.inlock.Unlock() + // read already <-|-> expected Q + // [lastReadSeq] | [lastReadSeq+1] [lastReadSeq+2] ...... + if c.inQ.isEqualsHead(c.lastReadSeq+1) && c.lastReadSeq < c.inQ.maxCtnSeq { + c.lastReadSeq = c.inQ.maxCtnSeq + availabled := c.inQ.get(c.inQ.maxCtnSeq) + availabled, _ = c.inQ.deleteBefore(availabled) + for i := availabled; i != nil; i = i.next { + c.inQReady = append(c.inQReady, i.payload...) + // data was copied, then could recycle memory + bpool.Put(i.buffer) + i.payload = nil + i.buffer = nil + } + return true + } + return false +} + +// should not call this function concurrently. +func (c *Conn) Read(buf []byte) (nr int, err error) { + for { + if len(c.inQReady) > 0 { + n := copy(buf, c.inQReady) + c.inQReady = c.inQReady[n:] + return n, nil + } + if !c.readInQ() { + if c.rtmo > 0 { + var tmo int64 + tmo, c.rtmo = c.rtmo, 0 + select { + case _, y := <-c.evRead: + if !y && len(c.inQReady) == 0 { + return 0, io.EOF + } + case <-NewTimerChan(tmo): + return 0, ErrIOTimeout + } + } else { + // only when evRead is closed and inQReady is empty + // then could reply eof + if _, y := <-c.evRead; !y && len(c.inQReady) == 0 { + return 0, io.EOF + } + } + } + } +} + +// should not call this function concurrently. +func (c *Conn) Write(data []byte) (nr int, err error) { + for len(data) > 0 && err == nil { + //buf := make([]byte, _MSS+_AH_SIZE) + buf := bpool.Get(c.mss + _AH_SIZE) + body := buf[_TH_SIZE+_CH_SIZE:] + n := copy(body, data) + nr += n + data = data[n:] + pk := &packet{flag: _F_DATA, payload: body[:n], buffer: buf[:_AH_SIZE+n]} + err = c.inputAndSend(pk) + } + return +} + +func (c *Conn) LocalAddr() net.Addr { + return c.sock.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.dest +} + +func (c *Conn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + c.SetWriteDeadline(t) + return nil +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + if d := t.UnixNano()/Millisecond - Now(); d > 0 { + c.rtmo = d + } + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + if d := t.UnixNano()/Millisecond - Now(); d > 0 { + c.wtmo = d + } + return nil +} diff --git a/weed/udptransfer/conn_test.go b/weed/udptransfer/conn_test.go new file mode 100644 index 000000000..080d778c0 --- /dev/null +++ b/weed/udptransfer/conn_test.go @@ -0,0 +1,86 @@ +package udptransfer + +import ( + "math/rand" + "sort" + "testing" +) + +var conn *Conn + +func init() { + conn = &Conn{ + outQ: newLinkedMap(_QModeOut), + inQ: newLinkedMap(_QModeIn), + } +} + +func assert(cond bool, t testing.TB, format string, args ...interface{}) { + if !cond { + t.Errorf(format, args...) + panic("last error") + } +} + +func Test_ordered_insert(t *testing.T) { + data := []byte{1} + var pk *packet + for i := int32(1); i < 33; i++ { + pk = &packet{ + seq: uint32(i), + payload: data, + } + conn.insertData(pk) + assert(conn.inQ.size() == i, t, "len inQ=%d", conn.inQ.size()) + assert(conn.inQ.maxCtnSeq == pk.seq, t, "lastCtnIn") + assert(!conn.inQDirty, t, "dirty") + } +} + +func Test_unordered_insert(t *testing.T) { + conn.inQ.reset() + + data := []byte{1} + var pk *packet + var seqs = make([]int, 0xfff) + // unordered insert, and assert size + for i := 1; i < len(seqs); i++ { + var seq uint32 + for conn.inQ.contains(seq) || seq == 0 { + seq = uint32(rand.Int31n(0xFFffff)) + } + seqs[i] = int(seq) + pk = &packet{ + seq: seq, + payload: data, + } + conn.insertData(pk) + assert(conn.inQ.size() == int32(i), t, "i=%d inQ.len=%d", i, conn.inQ.size()) + } + // assert lastCtnSeq + sort.Ints(seqs) + var zero = 0 + var last *int + for i := 0; i < len(seqs); i++ { + if i == 0 && seqs[0] != 0 { + last = &zero + break + } + if last != nil && seqs[i]-*last > 1 { + if i == 1 { + last = &zero + } + break + } + last = &seqs[i] + } + if *last != int(conn.inQ.maxCtnSeq) { + for i, j := range seqs { + if i < 10 { + t.Logf("seq %d", j) + } + } + } + assert(*last == int(conn.inQ.maxCtnSeq), t, "lastCtnSeq=%d but expected=%d", conn.inQ.maxCtnSeq, *last) + t.Logf("lastCtnSeq=%d dirty=%v", conn.inQ.maxCtnSeq, conn.inQDirty) +} diff --git a/weed/udptransfer/debug.go b/weed/udptransfer/debug.go new file mode 100644 index 000000000..dc8dbb501 --- /dev/null +++ b/weed/udptransfer/debug.go @@ -0,0 +1,95 @@ +package udptransfer + +import ( + "encoding/hex" + "fmt" + "log" + "os" + "os/signal" + "runtime" + "runtime/pprof" + "sync/atomic" + "syscall" +) + +var enable_pprof bool +var enable_stacktrace bool +var debug_inited int32 + +func set_debug_params(p *Params) { + if atomic.CompareAndSwapInt32(&debug_inited, 0, 1) { + debug = p.Debug + enable_pprof = p.EnablePprof + enable_stacktrace = p.Stacktrace + if enable_pprof { + f, err := os.Create("udptransfer.pprof") + if err != nil { + log.Fatalln(err) + } + pprof.StartCPUProfile(f) + } + } +} + +func (c *Conn) PrintState() { + log.Printf("inQ=%d inQReady=%d outQ=%d", c.inQ.size(), len(c.inQReady), c.outQ.size()) + log.Printf("inMaxCtnSeq=%d lastAck=%d lastReadSeq=%d", c.inQ.maxCtnSeq, c.lastAck, c.lastReadSeq) + if c.inPkCnt > 0 { + log.Printf("Rx pcnt=%d dups=%d %%d=%f%%", c.inPkCnt, c.inDupCnt, 100*float32(c.inDupCnt)/float32(c.inPkCnt)) + } + if c.outPkCnt > 0 { + log.Printf("Tx pcnt=%d dups=%d %%d=%f%%", c.outPkCnt, c.outDupCnt, 100*float32(c.outDupCnt)/float32(c.outPkCnt)) + } + log.Printf("current-rtt=%d FastRetransmit=%d", c.rtt, c.fRCnt) + if enable_stacktrace { + var buf = make([]byte, 6400) + for i := 0; i < 3; i++ { + n := runtime.Stack(buf, true) + if n >= len(buf) { + buf = make([]byte, len(buf)<<1) + } else { + buf = buf[:n] + break + } + } + fmt.Println(string(buf)) + } + if enable_pprof { + pprof.StopCPUProfile() + } +} + +func (c *Conn) internal_state() { + ev := make(chan os.Signal, 10) + signal.Notify(ev, syscall.Signal(12), syscall.SIGINT) + for v := range ev { + c.PrintState() + if v == syscall.SIGINT { + os.Exit(1) + } + } +} + +func printBits(b uint64, j, s, d uint32) { + fmt.Printf("bits=%064b j=%d seq=%d dis=%d\n", b, j, s, d) +} + +func dumpb(label string, buf []byte) { + log.Println(label, "\n", hex.Dump(buf)) +} + +func dumpQ(s string, q *linkedMap) { + var seqs = make([]uint32, 0, 20) + n := q.head + for i, m := int32(0), q.size(); i < m && n != nil; i++ { + seqs = append(seqs, n.seq) + n = n.next + if len(seqs) == 20 { + log.Printf("%s: Q=%d", s, seqs) + seqs = seqs[:0] + } + } + if len(seqs) > 0 { + log.Printf("%s: Q=%d", s, seqs) + } +} diff --git a/weed/udptransfer/endpoint.go b/weed/udptransfer/endpoint.go new file mode 100644 index 000000000..d19d1a4f5 --- /dev/null +++ b/weed/udptransfer/endpoint.go @@ -0,0 +1,339 @@ +package udptransfer + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "math/rand" + "net" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/cloudflare/golibs/bytepool" +) + +const ( + _SO_BUF_SIZE = 8 << 20 +) + +var ( + bpool bytepool.BytePool +) + +type Params struct { + LocalAddr string + Bandwidth int64 + Mtu int + IsServ bool + FastRetransmit bool + FlatTraffic bool + EnablePprof bool + Stacktrace bool + Debug int +} + +type connID struct { + lid uint32 + rid uint32 +} + +type Endpoint struct { + udpconn *net.UDPConn + state int32 + idSeq uint32 + isServ bool + listenChan chan *Conn + lRegistry map[uint32]*Conn + rRegistry map[string][]uint32 + mlock sync.RWMutex + timeout *time.Timer + params Params +} + +func (c *connID) setRid(b []byte) { + c.rid = binary.BigEndian.Uint32(b[_MAGIC_SIZE+6:]) +} + +func init() { + bpool.Init(0, 2000) + rand.Seed(NowNS()) +} + +func NewEndpoint(p *Params) (*Endpoint, error) { + set_debug_params(p) + if p.Bandwidth <= 0 || p.Bandwidth > 100 { + return nil, fmt.Errorf("bw->(0,100]") + } + conn, err := net.ListenPacket("udp", p.LocalAddr) + if err != nil { + return nil, err + } + e := &Endpoint{ + udpconn: conn.(*net.UDPConn), + idSeq: 1, + isServ: p.IsServ, + listenChan: make(chan *Conn, 1), + lRegistry: make(map[uint32]*Conn), + rRegistry: make(map[string][]uint32), + timeout: time.NewTimer(0), + params: *p, + } + if e.isServ { + e.state = _S_EST0 + } else { // client + e.state = _S_EST1 + e.idSeq = uint32(rand.Int31()) + } + e.params.Bandwidth = p.Bandwidth << 20 // mbps to bps + e.udpconn.SetReadBuffer(_SO_BUF_SIZE) + go e.internal_listen() + return e, nil +} + +func (e *Endpoint) internal_listen() { + const rtmo = time.Duration(30*time.Second) + var id connID + for { + //var buf = make([]byte, 1600) + var buf = bpool.Get(1600) + e.udpconn.SetReadDeadline(time.Now().Add(rtmo)) + n, addr, err := e.udpconn.ReadFromUDP(buf) + if err == nil && n >= _AH_SIZE { + buf = buf[:n] + e.getConnID(&id, buf) + + switch id.lid { + case 0: // new connection + if e.isServ { + go e.acceptNewConn(id, addr, buf) + } else { + dumpb("drop", buf) + } + + case _INVALID_SEQ: + dumpb("drop invalid", buf) + + default: // old connection + e.mlock.RLock() + conn := e.lRegistry[id.lid] + e.mlock.RUnlock() + if conn != nil { + e.dispatch(conn, buf) + } else { + e.resetPeer(addr, id) + dumpb("drop null", buf) + } + } + + } else if err != nil { + // idle process + if nerr, y := err.(net.Error); y && nerr.Timeout() { + e.idleProcess() + continue + } + // other errors + if atomic.LoadInt32(&e.state) == _S_FIN { + return + } else { + log.Println("Error: read sock", err) + } + } + } +} + +func (e *Endpoint) idleProcess() { + // recycle/shrink memory + bpool.Drain() + e.mlock.Lock() + defer e.mlock.Unlock() + // reset urgent + for _, c := range e.lRegistry { + c.outlock.Lock() + if c.outQ.size() == 0 && c.urgent != 0 { + c.urgent = 0 + } + c.outlock.Unlock() + } +} + +func (e *Endpoint) Dial(addr string) (*Conn, error) { + rAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + e.mlock.Lock() + e.idSeq++ + id := connID{e.idSeq, 0} + conn := NewConn(e, rAddr, id) + e.lRegistry[id.lid] = conn + e.mlock.Unlock() + if atomic.LoadInt32(&e.state) != _S_FIN { + err = conn.initConnection(nil) + return conn, err + } + return nil, io.EOF +} + +func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) { + rKey := addr.String() + e.mlock.Lock() + // map: remoteAddr => remoteConnID + // filter duplicated syn packets + if newArr := insertRid(e.rRegistry[rKey], id.rid); newArr != nil { + e.rRegistry[rKey] = newArr + } else { + e.mlock.Unlock() + log.Println("Warn: duplicated connection", addr) + return + } + e.idSeq++ + id.lid = e.idSeq + conn := NewConn(e, addr, id) + e.lRegistry[id.lid] = conn + e.mlock.Unlock() + err := conn.initConnection(buf) + if err == nil { + select { + case e.listenChan <- conn: + case <-time.After(_10ms): + log.Println("Warn: no listener") + } + } else { + e.removeConn(id, addr) + log.Println("Error: init_connection", addr, err) + } +} + +func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) { + e.mlock.Lock() + delete(e.lRegistry, id.lid) + rKey := addr.String() + if newArr := deleteRid(e.rRegistry[rKey], id.rid); newArr != nil { + if len(newArr) > 0 { + e.rRegistry[rKey] = newArr + } else { + delete(e.rRegistry, rKey) + } + } + e.mlock.Unlock() +} + +// net.Listener +func (e *Endpoint) Close() error { + state := atomic.LoadInt32(&e.state) + if state > 0 && atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) { + err := e.udpconn.Close() + e.lRegistry = nil + e.rRegistry = nil + select { // release listeners + case e.listenChan <- nil: + default: + } + return err + } + return nil +} + +// net.Listener +func (e *Endpoint) Addr() net.Addr { + return e.udpconn.LocalAddr() +} + +// net.Listener +func (e *Endpoint) Accept() (net.Conn, error) { + if atomic.LoadInt32(&e.state) == _S_EST0 { + return <-e.listenChan, nil + } else { + return nil, io.EOF + } +} + +func (e *Endpoint) Listen() *Conn { + if atomic.LoadInt32(&e.state) == _S_EST0 { + return <-e.listenChan + } else { + return nil + } +} + +// tmo in MS +func (e *Endpoint) ListenTimeout(tmo int64) *Conn { + if tmo <= 0 { + return e.Listen() + } + if atomic.LoadInt32(&e.state) == _S_EST0 { + select { + case c := <-e.listenChan: + return c + case <-NewTimerChan(tmo): + } + } + return nil +} + +func (e *Endpoint) getConnID(idPtr *connID, buf []byte) { + // TODO determine magic header + magicAndLen := binary.BigEndian.Uint64(buf) + if int(magicAndLen&0xFFff) == len(buf) { + id := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:]) + idPtr.lid = uint32(id >> 32) + idPtr.rid = uint32(id) + } else { + idPtr.lid = _INVALID_SEQ + } +} + +func (e *Endpoint) dispatch(c *Conn, buf []byte) { + e.timeout.Reset(30*time.Millisecond) + select { + case c.evRecv <- buf: + case <-e.timeout.C: + log.Println("Warn: dispatch packet failed") + } +} + +func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) { + pk := &packet{flag: _F_FIN | _F_RESET} + buf := nodeOf(pk).marshall(id) + e.udpconn.WriteToUDP(buf, addr) +} + +type u32Slice []uint32 + +func (p u32Slice) Len() int { return len(p) } +func (p u32Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p u32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// if the rid is not existed in array then insert it return new array +func insertRid(array []uint32, rid uint32) []uint32 { + if len(array) > 0 { + pos := sort.Search(len(array), func(n int) bool { + return array[n] >= rid + }) + if pos < len(array) && array[pos] == rid { + return nil + } + } + array = append(array, rid) + sort.Sort(u32Slice(array)) + return array +} + +// if rid was existed in array then delete it return new array +func deleteRid(array []uint32, rid uint32) []uint32 { + if len(array) > 0 { + pos := sort.Search(len(array), func(n int) bool { + return array[n] >= rid + }) + if pos < len(array) && array[pos] == rid { + newArray := make([]uint32, len(array)-1) + n := copy(newArray, array[:pos]) + copy(newArray[n:], array[pos+1:]) + return newArray + } + } + return nil +} diff --git a/weed/udptransfer/endpoint_test.go b/weed/udptransfer/endpoint_test.go new file mode 100644 index 000000000..01be073bc --- /dev/null +++ b/weed/udptransfer/endpoint_test.go @@ -0,0 +1,42 @@ +package udptransfer + +import ( + "fmt" + "math/rand" + "sort" + "testing" +) + +func Test_insert_delete_rid(t *testing.T) { + var a []uint32 + var b = make([]uint32, 0, 1e3) + var uniq = make(map[uint32]int) + // insert into a with random + for i := 0; i < cap(b); i++ { + n := uint32(rand.Int31()) + if _, y := uniq[n]; !y { + b = append(b, n) + uniq[n] = 1 + } + dups := 1 + if i&0xf == 0xf { + dups = 3 + } + for j := 0; j < dups; j++ { + if aa := insertRid(a, n); aa != nil { + a = aa + } + } + } + sort.Sort(u32Slice(b)) + bStr := fmt.Sprintf("%d", b) + aStr := fmt.Sprintf("%d", a) + assert(aStr == bStr, t, "a!=b") + + for i := 0; i < len(b); i++ { + if aa := deleteRid(a, b[i]); aa != nil { + a = aa + } + } + assert(len(a) == 0, t, "a!=0") +} diff --git a/weed/udptransfer/linked.go b/weed/udptransfer/linked.go new file mode 100644 index 000000000..061cfacd2 --- /dev/null +++ b/weed/udptransfer/linked.go @@ -0,0 +1,333 @@ +package udptransfer + +type qNode struct { + *packet + prev *qNode + next *qNode + sent int64 // last sent time + sent_1 int64 // prev sent time + miss int // sack miss count +} + +type linkedMap struct { + head *qNode + tail *qNode + qmap map[uint32]*qNode + lastIns *qNode + maxCtnSeq uint32 + mode int +} + +const ( + _QModeIn = 1 + _QModeOut = 2 +) + +func newLinkedMap(qmode int) *linkedMap { + return &linkedMap{ + qmap: make(map[uint32]*qNode), + mode: qmode, + } +} + +func (l *linkedMap) get(seq uint32) (i *qNode) { + i = l.qmap[seq] + return +} + +func (l *linkedMap) contains(seq uint32) bool { + _, y := l.qmap[seq] + return y +} + +func (l *linkedMap) size() int32 { + return int32(len(l.qmap)) +} + +func (l *linkedMap) reset() { + l.head = nil + l.tail = nil + l.lastIns = nil + l.maxCtnSeq = 0 + l.qmap = make(map[uint32]*qNode) +} + +func (l *linkedMap) isEqualsHead(seq uint32) bool { + return l.head != nil && seq == l.head.seq +} + +func (l *linkedMap) distanceOfHead(seq uint32) int32 { + if l.head != nil { + return int32(seq - l.head.seq) + } else { + return -1 + } +} + +func (l *linkedMap) appendTail(one *qNode) { + if l.tail != nil { + l.tail.next = one + one.prev = l.tail + l.tail = one + } else { + l.head = one + l.tail = one + } + l.qmap[one.seq] = one +} + +// xxx - n - yyy +// xxx - yyy +func (l *linkedMap) deleteAt(n *qNode) { + x, y := n.prev, n.next + if x != nil { + x.next = y + } else { + l.head = y + } + if y != nil { + y.prev = x + } else { + l.tail = x + } + n.prev, n.next = nil, nil + delete(l.qmap, n.seq) +} + +// delete with n <- ...n | +func (l *linkedMap) deleteBefore(n *qNode) (left *qNode, deleted int32) { + for i := n; i != nil; i = i.prev { + delete(l.qmap, i.seq) + if i.scnt != _SENT_OK { + deleted++ + // only outQ could delete at here + if l.mode == _QModeOut { + bpool.Put(i.buffer) + i.buffer = nil + } + } + } + left = l.head + l.head = n.next + n.next = nil + if l.head != nil { + l.head.prev = nil + } else { // n.next is the tail and is nil + l.tail = nil + } + return +} + +// xxx - ref +// xxx - one - ref +func (l *linkedMap) insertBefore(ref, one *qNode) { + x := ref.prev + if x != nil { + x.next = one + one.prev = x + } else { + l.head = one + } + ref.prev = one + one.next = ref + l.qmap[one.seq] = one +} + +// ref - zzz +// ref - one - zzz +func (l *linkedMap) insertAfter(ref, one *qNode) { + z := ref.next + if z == nil { // append + ref.next = one + l.tail = one + } else { // insert mid + z.prev = one + ref.next = one + } + one.prev = ref + one.next = z + l.qmap[one.seq] = one +} + +// baseHead: the left outside boundary +// if inserted, return the distance between newNode with baseHead +func (l *linkedMap) searchInsert(one *qNode, baseHead uint32) (dis int64) { + for i := l.tail; i != nil; i = i.prev { + dis = int64(one.seq) - int64(i.seq) + if dis > 0 { + l.insertAfter(i, one) + return + } else if dis == 0 { + // duplicated + return + } + } + if one.seq <= baseHead { + return 0 + } + if l.head != nil { + l.insertBefore(l.head, one) + } else { + l.head = one + l.tail = one + l.qmap[one.seq] = one + } + dis = int64(one.seq) - int64(baseHead) + return +} + +func (l *linkedMap) updateContinuous(i *qNode) bool { + var lastCtnSeq = l.maxCtnSeq + for ; i != nil; i = i.next { + if i.seq-lastCtnSeq == 1 { + lastCtnSeq = i.seq + } else { + break + } + } + if lastCtnSeq != l.maxCtnSeq { + l.maxCtnSeq = lastCtnSeq + return true + } + return false +} + +func (l *linkedMap) isWholeContinuous() bool { + return l.tail != nil && l.maxCtnSeq == l.tail.seq +} + +/* +func (l *linkedMap) searchMaxContinued(baseHead uint32) (*qNode, bool) { + var last *qNode + for i := l.head; i != nil; i = i.next { + if last != nil { + if i.seq-last.seq > 1 { + return last, true + } + } else { + if i.seq != baseHead { + return nil, false + } + } + last = i + } + if last != nil { + return last, true + } else { + return nil, false + } +} +*/ + +func (q *qNode) forward(n int) *qNode { + for ; n > 0 && q != nil; n-- { + q = q.next + } + return q +} + +// prev of bitmap start point +func (l *linkedMap) makeHolesBitmap(prev uint32) ([]uint64, uint32) { + var start *qNode + var bits uint64 + var j uint32 + for i := l.head; i != nil; i = i.next { + if i.seq >= prev+1 { + start = i + break + } + } + var bitmap []uint64 + // search start which is the recent successor of [prev] +Iterator: + for i := start; i != nil; i = i.next { + dis := i.seq - prev + prev = i.seq + j += dis // j is next bit index + for j >= 65 { + if len(bitmap) >= 20 { // bitmap too long + break Iterator + } + bitmap = append(bitmap, bits) + bits = 0 + j -= 64 + } + bits |= 1 << (j - 1) + } + if j > 0 { + // j -> (0, 64] + bitmap = append(bitmap, bits) + } + return bitmap, j +} + +// from= the bitmap start point +func (l *linkedMap) deleteByBitmap(bmap []uint64, from uint32, tailBitsLen uint32) (deleted, missed int32, lastContinued bool) { + var start = l.qmap[from] + if start != nil { + // delete predecessors + if pred := start.prev; pred != nil { + _, n := l.deleteBefore(pred) + deleted += n + } + } else { + // [from] is out of bounds + return + } + var j, bitsLen uint32 + var bits uint64 + bits, bmap = bmap[0], bmap[1:] + if len(bmap) > 0 { + bitsLen = 64 + } else { + // bmap.len==1, tail is here + bitsLen = tailBitsLen + } + // maxContinued will save the max continued node (from [start]) which could be deleted safely. + // keep the queue smallest + var maxContinued *qNode + lastContinued = true + + for i := start; i != nil; j++ { + if j >= bitsLen { + if len(bmap) > 0 { + j = 0 + bits, bmap = bmap[0], bmap[1:] + if len(bmap) > 0 { + bitsLen = 64 + } else { + bitsLen = tailBitsLen + } + } else { + // no more pages + goto finished + } + } + if bits&1 == 1 { + if lastContinued { + maxContinued = i + } + if i.scnt != _SENT_OK { + // no mark means first deleting + deleted++ + } + // don't delete, just mark it + i.scnt = _SENT_OK + } else { + // known it may be lost + if i.miss == 0 { + missed++ + } + i.miss++ + lastContinued = false + } + bits >>= 1 + i = i.next + } + +finished: + if maxContinued != nil { + l.deleteBefore(maxContinued) + } + return +} diff --git a/weed/udptransfer/linked_test.go b/weed/udptransfer/linked_test.go new file mode 100644 index 000000000..563d6264a --- /dev/null +++ b/weed/udptransfer/linked_test.go @@ -0,0 +1,236 @@ +package udptransfer + +import ( + "fmt" + "math/rand" + "testing" +) + +var lmap *linkedMap + +func init() { + lmap = newLinkedMap(_QModeIn) +} + +func node(seq int) *qNode { + return &qNode{packet: &packet{seq: uint32(seq)}} +} + +func Test_get(t *testing.T) { + var i, n *qNode + assert(!lmap.contains(0), t, "0 nil") + n = node(0) + lmap.head = n + lmap.tail = n + lmap.qmap[0] = n + i = lmap.get(0) + assert(i == n, t, "0=n") +} + +func Test_insert(t *testing.T) { + lmap.reset() + n := node(1) + // appendTail + lmap.appendTail(n) + assert(lmap.head == n, t, "head n") + assert(lmap.tail == n, t, "head n") + n = node(2) + lmap.appendTail(n) + assert(lmap.head != n, t, "head n") + assert(lmap.tail == n, t, "head n") + assert(lmap.size() == 2, t, "size") +} + +func Test_insertAfter(t *testing.T) { + n1 := lmap.get(1) + n2 := n1.next + n3 := node(3) + lmap.insertAfter(n1, n3) + assert(n1.next == n3, t, "n3") + assert(n1 == n3.prev, t, "left n3") + assert(n2 == n3.next, t, "n3 right") +} + +func Test_insertBefore(t *testing.T) { + n3 := lmap.get(3) + n2 := n3.next + n4 := node(4) + lmap.insertAfter(n3, n4) + assert(n3.next == n4, t, "n4") + assert(n3 == n4.prev, t, "left n4") + assert(n2 == n4.next, t, "n4 right") +} + +func Test_deleteBefore(t *testing.T) { + lmap.reset() + for i := 1; i < 10; i++ { + n := node(i) + lmap.appendTail(n) + } + + var assertRangeEquals = func(n *qNode, start, wantCount int) { + var last *qNode + var count int + for ; n != nil; n = n.next { + assert(int(n.seq) == start, t, "nseq=%d start=%d", n.seq, start) + last = n + start++ + count++ + } + assert(last.next == nil, t, "tail nil") + assert(count == wantCount, t, "count") + } + assertRangeEquals(lmap.head, 1, 9) + var n *qNode + n = lmap.get(3) + n, _ = lmap.deleteBefore(n) + assertRangeEquals(n, 1, 3) + assert(lmap.head.seq == 4, t, "head") + + n = lmap.get(8) + n, _ = lmap.deleteBefore(n) + assertRangeEquals(n, 4, 5) + assert(lmap.head.seq == 9, t, "head") + + n = lmap.get(9) + n, _ = lmap.deleteBefore(n) + assertRangeEquals(n, 9, 1) + + assert(lmap.size() == 0, t, "size 0") + assert(lmap.head == nil, t, "head nil") + assert(lmap.tail == nil, t, "tail nil") +} + +func testBitmap(t *testing.T, bmap []uint64, prev uint32) { + var j uint + var k int + bits := bmap[k] + t.Logf("test-%d %016x", k, bits) + var checkNextPage = func() { + if j >= 64 { + j = 0 + k++ + bits = bmap[k] + t.Logf("test-%d %016x", k, bits) + } + } + for i := lmap.head; i != nil && k < len(bmap); i = i.next { + checkNextPage() + dis := i.seq - prev + prev = i.seq + if dis == 1 { + bit := (bits >> j) & 1 + assert(bit == 1, t, "1 bit=%d j=%d", bit, j) + j++ + } else { + for ; dis > 0; dis-- { + checkNextPage() + bit := (bits >> j) & 1 + want := uint64(0) + if dis == 1 { + want = 1 + } + assert(bit == want, t, "?=%d bit=%d j=%d", want, bit, j) + j++ + } + } + } + // remains bits should be 0 + for i := j & 63; i > 0; i-- { + bit := (bits >> j) & 1 + assert(bit == 0, t, "00 bit=%d j=%d", bit, j) + j++ + } +} + +func Test_bitmap(t *testing.T) { + var prev uint32 + var head uint32 = prev + 1 + + lmap.reset() + // test 66-%3 and record holes + var holes = make([]uint32, 0, 50) + for i := head; i < 366; i++ { + if i%3 == 0 { + holes = append(holes, i) + continue + } + n := node(int(i)) + lmap.appendTail(n) + } + bmap, tbl := lmap.makeHolesBitmap(prev) + testBitmap(t, bmap, prev) + + lmap.reset() + // full 66, do deleteByBitmap then compare + for i := head; i < 366; i++ { + n := node(int(i)) + lmap.appendTail(n) + } + + lmap.deleteByBitmap(bmap, head, tbl) + var holesResult = make([]uint32, 0, 50) + for i := lmap.head; i != nil; i = i.next { + if i.scnt != _SENT_OK { + holesResult = append(holesResult, i.seq) + } + } + a := fmt.Sprintf("%x", holes) + b := fmt.Sprintf("%x", holesResult) + assert(a == b, t, "deleteByBitmap \na=%s \nb=%s", a, b) + + lmap.reset() + // test stride across page 1 + for i := head; i < 69; i++ { + if i >= 63 && i <= 65 { + continue + } + n := node(int(i)) + lmap.appendTail(n) + } + bmap, _ = lmap.makeHolesBitmap(prev) + testBitmap(t, bmap, prev) + + lmap.reset() + prev = 65 + head = prev + 1 + // test stride across page 0 + for i := head; i < 68; i++ { + n := node(int(i)) + lmap.appendTail(n) + } + bmap, _ = lmap.makeHolesBitmap(prev) + testBitmap(t, bmap, prev) +} + +var ackbitmap []uint64 + +func init_benchmark_map() { + if lmap.size() != 640 { + lmap.reset() + for i := 1; i <= 640; i++ { + lmap.appendTail(node(i)) + } + ackbitmap = make([]uint64, 10) + for i := 0; i < len(ackbitmap); i++ { + n := rand.Int63() + ackbitmap[i] = uint64(n) << 1 + } + } +} + +func Benchmark_make_bitmap(b *testing.B) { + init_benchmark_map() + + for i := 0; i < b.N; i++ { + lmap.makeHolesBitmap(0) + } +} + +func Benchmark_apply_bitmap(b *testing.B) { + init_benchmark_map() + + for i := 0; i < b.N; i++ { + lmap.deleteByBitmap(ackbitmap, 1, 64) + } +} diff --git a/weed/udptransfer/packet.go b/weed/udptransfer/packet.go new file mode 100644 index 000000000..2413f1ccf --- /dev/null +++ b/weed/udptransfer/packet.go @@ -0,0 +1,142 @@ +package udptransfer + +import ( + "encoding/binary" + "fmt" +) + +const ( + _F_NIL = 0 + _F_SYN = 1 + _F_ACK = 1 << 1 + _F_SACK = 1 << 2 + _F_TIME = 1 << 3 + _F_DATA = 1 << 4 + // reserved = 1 << 5 + _F_RESET = 1 << 6 + _F_FIN = 1 << 7 +) + +var packetTypeNames = map[byte]string{ + 0: "NOOP", + 1: "SYN", + 2: "ACK", + 3: "SYN+ACK", + 4: "SACK", + 8: "TIME", + 12: "SACK+TIME", + 16: "DATA", + 64: "RESET", + 128: "FIN", + 192: "FIN+RESET", +} + +const ( + _S_FIN = iota + _S_FIN0 + _S_FIN1 + _S_SYN0 + _S_SYN1 + _S_EST0 + _S_EST1 +) + +// Magic-6 | TH-10 | CH-10 | payload +const ( + _MAGIC_SIZE = 6 + _TH_SIZE = 10 + _MAGIC_SIZE + _CH_SIZE = 10 + _AH_SIZE = _TH_SIZE + _CH_SIZE +) + +const ( + // Max UDP payload: 1500 MTU - 20 IP hdr - 8 UDP hdr = 1472 bytes + // Then: MSS = 1472-26 = 1446 + // And For ADSL: 1446-8 = 1438 + _MSS = 1438 +) + +const ( + _SENT_OK = 0xff +) + +type packet struct { + seq uint32 + ack uint32 + flag uint8 + scnt uint8 + payload []byte + buffer []byte +} + +func (p *packet) marshall(id connID) []byte { + buf := p.buffer + if buf == nil { + buf = make([]byte, _AH_SIZE+len(p.payload)) + copy(buf[_TH_SIZE+10:], p.payload) + } + binary.BigEndian.PutUint16(buf[_MAGIC_SIZE:], uint16(len(buf))) + binary.BigEndian.PutUint32(buf[_MAGIC_SIZE+2:], id.rid) + binary.BigEndian.PutUint32(buf[_MAGIC_SIZE+6:], id.lid) + binary.BigEndian.PutUint32(buf[_TH_SIZE:], p.seq) + binary.BigEndian.PutUint32(buf[_TH_SIZE+4:], p.ack) + buf[_TH_SIZE+8] = p.flag + buf[_TH_SIZE+9] = p.scnt + return buf +} + +func unmarshall(pk *packet, buf []byte) { + if len(buf) >= _CH_SIZE { + pk.seq = binary.BigEndian.Uint32(buf) + pk.ack = binary.BigEndian.Uint32(buf[4:]) + pk.flag = buf[8] + pk.scnt = buf[9] + pk.payload = buf[10:] + } +} + +func (n *qNode) String() string { + now := Now() + return fmt.Sprintf("type=%s seq=%d scnt=%d sndtime~=%d,%d miss=%d", + packetTypeNames[n.flag], n.seq, n.scnt, n.sent-now, n.sent_1-now, n.miss) +} + +func maxI64(a, b int64) int64 { + if a >= b { + return a + } else { + return b + } +} + +func maxU32(a, b uint32) uint32 { + if a >= b { + return a + } else { + return b + } +} + +func minI64(a, b int64) int64 { + if a <= b { + return a + } else { + return b + } +} + +func maxI32(a, b int32) int32 { + if a >= b { + return a + } else { + return b + } +} + +func minI32(a, b int32) int32 { + if a <= b { + return a + } else { + return b + } +} diff --git a/weed/udptransfer/state.go b/weed/udptransfer/state.go new file mode 100644 index 000000000..e0b4f1791 --- /dev/null +++ b/weed/udptransfer/state.go @@ -0,0 +1,516 @@ +package udptransfer + +import ( + "errors" + "log" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + _10ms = time.Millisecond * 10 + _100ms = time.Millisecond * 100 +) + +const ( + _FIN_ACK_SEQ uint32 = 0xffFF0000 + _INVALID_SEQ uint32 = 0xffFFffFF +) + +var ( + ErrIOTimeout error = &TimeoutError{} + ErrUnknown = errors.New("Unknown error") + ErrInexplicableData = errors.New("Inexplicable data") + ErrTooManyAttempts = errors.New("Too many attempts to connect") +) + +type TimeoutError struct{} + +func (e *TimeoutError) Error() string { return "i/o timeout" } +func (e *TimeoutError) Timeout() bool { return true } +func (e *TimeoutError) Temporary() bool { return true } + +type Conn struct { + sock *net.UDPConn + dest *net.UDPAddr + edp *Endpoint + connID connID // 8 bytes + // events + evRecv chan []byte + evRead chan byte + evSend chan byte + evSWnd chan byte + evAck chan byte + evClose chan byte + // protocol state + inlock sync.Mutex + outlock sync.Mutex + state int32 + mySeq uint32 + swnd int32 + cwnd int32 + missed int32 + outPending int32 + lastAck uint32 + lastAckTime int64 + lastAckTime2 int64 + lastShrink int64 + lastRstMis int64 + ato int64 + rto int64 + rtt int64 + srtt int64 + mdev int64 + rtmo int64 + wtmo int64 + tSlot int64 + tSlotT0 int64 + lastSErr int64 + // queue + outQ *linkedMap + inQ *linkedMap + inQReady []byte + inQDirty bool + lastReadSeq uint32 // last user read seq + // params + bandwidth int64 + fastRetransmit bool + flatTraffic bool + mss int + // statistics + urgent int + inPkCnt int + inDupCnt int + outPkCnt int + outDupCnt int + fRCnt int +} + +func NewConn(e *Endpoint, dest *net.UDPAddr, id connID) *Conn { + c := &Conn{ + sock: e.udpconn, + dest: dest, + edp: e, + connID: id, + evRecv: make(chan []byte, 128), + evRead: make(chan byte, 1), + evSWnd: make(chan byte, 2), + evSend: make(chan byte, 4), + evAck: make(chan byte, 1), + evClose: make(chan byte, 2), + outQ: newLinkedMap(_QModeOut), + inQ: newLinkedMap(_QModeIn), + } + p := e.params + c.bandwidth = p.Bandwidth + c.fastRetransmit = p.FastRetransmit + c.flatTraffic = p.FlatTraffic + c.mss = _MSS + if dest.IP.To4() == nil { + // typical ipv6 header length=40 + c.mss -= 20 + } + return c +} + +func (c *Conn) initConnection(buf []byte) (err error) { + if buf == nil { + err = c.initDialing() + } else { //server + err = c.acceptConnection(buf[_TH_SIZE:]) + } + if err != nil { + return + } + if c.state == _S_EST1 { + c.lastReadSeq = c.lastAck + c.inQ.maxCtnSeq = c.lastAck + c.rtt = maxI64(c.rtt, _MIN_RTT) + c.mdev = c.rtt << 1 + c.srtt = c.rtt << 3 + c.rto = maxI64(c.rtt*2, _MIN_RTO) + c.ato = maxI64(c.rtt>>4, _MIN_ATO) + c.ato = minI64(c.ato, _MAX_ATO) + // initial cwnd + c.swnd = calSwnd(c.bandwidth, c.rtt) >> 1 + c.cwnd = 8 + go c.internalRecvLoop() + go c.internalSendLoop() + go c.internalAckLoop() + if debug >= 0 { + go c.internal_state() + } + return nil + } else { + return ErrUnknown + } +} + +func (c *Conn) initDialing() error { + // first syn + pk := &packet{ + seq: c.mySeq, + flag: _F_SYN, + } + item := nodeOf(pk) + var buf []byte + c.state = _S_SYN0 + t0 := Now() + for i := 0; i < _MAX_RETRIES && c.state == _S_SYN0; i++ { + // send syn + c.internalWrite(item) + select { + case buf = <-c.evRecv: + c.rtt = Now() - t0 + c.state = _S_SYN1 + c.connID.setRid(buf) + buf = buf[_TH_SIZE:] + case <-time.After(time.Second): + continue + } + } + if c.state == _S_SYN0 { + return ErrTooManyAttempts + } + + unmarshall(pk, buf) + // expected syn+ack + if pk.flag == _F_SYN|_F_ACK && pk.ack == c.mySeq { + if scnt := pk.scnt - 1; scnt > 0 { + c.rtt -= int64(scnt) * 1e3 + } + log.Println("rtt", c.rtt) + c.state = _S_EST0 + // build ack3 + pk.scnt = 0 + pk.ack = pk.seq + pk.flag = _F_ACK + item := nodeOf(pk) + // send ack3 + c.internalWrite(item) + // update lastAck + c.logAck(pk.ack) + c.state = _S_EST1 + return nil + } else { + return ErrInexplicableData + } +} + +func (c *Conn) acceptConnection(buf []byte) error { + var pk = new(packet) + var item *qNode + unmarshall(pk, buf) + // expected syn + if pk.flag == _F_SYN { + c.state = _S_SYN1 + // build syn+ack + pk.ack = pk.seq + pk.seq = c.mySeq + pk.flag |= _F_ACK + // update lastAck + c.logAck(pk.ack) + item = nodeOf(pk) + item.scnt = pk.scnt - 1 + } else { + dumpb("Syn1 ?", buf) + return ErrInexplicableData + } + for i := 0; i < 5 && c.state == _S_SYN1; i++ { + t0 := Now() + // reply syn+ack + c.internalWrite(item) + // recv ack3 + select { + case buf = <-c.evRecv: + c.state = _S_EST0 + c.rtt = Now() - t0 + buf = buf[_TH_SIZE:] + log.Println("rtt", c.rtt) + case <-time.After(time.Second): + continue + } + } + if c.state == _S_SYN1 { + return ErrTooManyAttempts + } + + pk = new(packet) + unmarshall(pk, buf) + // expected ack3 + if pk.flag == _F_ACK && pk.ack == c.mySeq { + c.state = _S_EST1 + } else { + // if ack3 lost, resend syn+ack 3-times + // and drop these coming data + if pk.flag&_F_DATA != 0 && pk.seq > c.lastAck { + c.internalWrite(item) + c.state = _S_EST1 + } else { + dumpb("Ack3 ?", buf) + return ErrInexplicableData + } + } + return nil +} + +// 20,20,20,20, 100,100,100,100, 1s,1s,1s,1s +func selfSpinWait(fn func() bool) error { + const _MAX_SPIN = 12 + for i := 0; i < _MAX_SPIN; i++ { + if fn() { + return nil + } else if i <= 3 { + time.Sleep(_10ms * 2) + } else if i <= 7 { + time.Sleep(_100ms) + } else { + time.Sleep(time.Second) + } + } + return ErrIOTimeout +} + +func (c *Conn) IsClosed() bool { + return atomic.LoadInt32(&c.state) <= _S_FIN1 +} + +/* +active close: +1 <- send fin-W: closeW() + before sending, ensure all outQ items has beed sent out and all of them has been acked. +2 -> wait to recv ack{fin-W} + then trigger closeR, including send fin-R and wait to recv ack{fin-R} + +passive close: +-> fin: + if outQ is not empty then self-spin wait. + if outQ empty, send ack{fin-W} then goto closeW(). +*/ +func (c *Conn) Close() (err error) { + if !atomic.CompareAndSwapInt32(&c.state, _S_EST1, _S_FIN0) { + return selfSpinWait(func() bool { + return atomic.LoadInt32(&c.state) == _S_FIN + }) + } + var err0 error + err0 = c.closeW() + // waiting for fin-2 of peer + err = selfSpinWait(func() bool { + select { + case v := <-c.evClose: + if v == _S_FIN { + return true + } else { + time.AfterFunc(_100ms, func() { c.evClose <- v }) + } + default: + } + return false + }) + defer c.afterShutdown() + if err != nil { + // backup path for wait ack(finW) timeout + c.closeR(nil) + } + if err0 != nil { + return err0 + } else { + return + } +} + +func (c *Conn) beforeCloseW() (err error) { + // check outQ was empty and all has been acked. + // self-spin waiting + for i := 0; i < 2; i++ { + err = selfSpinWait(func() bool { + return atomic.LoadInt32(&c.outPending) <= 0 + }) + if err == nil { + break + } + } + // send fin, reliably + c.outlock.Lock() + c.mySeq++ + c.outPending++ + pk := &packet{seq: c.mySeq, flag: _F_FIN} + item := nodeOf(pk) + c.outQ.appendTail(item) + c.internalWrite(item) + c.outlock.Unlock() + c.evSWnd <- _VSWND_ACTIVE + return +} + +func (c *Conn) closeW() (err error) { + // close resource of sending + defer c.afterCloseW() + // send fin + err = c.beforeCloseW() + var closed bool + var max = 20 + if c.rtt > 200 { + max = int(c.rtt) / 10 + } + // waiting for outQ means: + // 1. all outQ has been acked, for passive + // 2. fin has been acked, for active + for i := 0; i < max && (atomic.LoadInt32(&c.outPending) > 0 || !closed); i++ { + select { + case v := <-c.evClose: + if v == _S_FIN0 { + // namely, last fin has been acked. + closed = true + } else { + time.AfterFunc(_100ms, func() { c.evClose <- v }) + } + case <-time.After(_100ms): + } + } + if closed || err != nil { + return + } else { + return ErrIOTimeout + } +} + +func (c *Conn) afterCloseW() { + // can't close(c.evRecv), avoid endpoint dispatch exception + // stop pending inputAndSend + select { + case c.evSend <- _CLOSE: + default: + } + // stop internalSendLoop + c.evSWnd <- _CLOSE +} + +// called by active and passive close() +func (c *Conn) afterShutdown() { + // stop internalRecvLoop + c.evRecv <- nil + // remove registry + c.edp.removeConn(c.connID, c.dest) + log.Println("shutdown", c.state) +} + +// trigger by reset +func (c *Conn) forceShutdownWithLock() { + c.outlock.Lock() + defer c.outlock.Unlock() + c.forceShutdown() +} + +// called by: +// 1/ send exception +// 2/ recv reset +// drop outQ and force shutdown +func (c *Conn) forceShutdown() { + if atomic.CompareAndSwapInt32(&c.state, _S_EST1, _S_FIN) { + defer c.afterShutdown() + // stop sender + for i := 0; i < cap(c.evSend); i++ { + select { + case <-c.evSend: + default: + } + } + select { + case c.evSend <- _CLOSE: + default: + } + c.outQ.reset() + // stop reader + close(c.evRead) + c.inQ.reset() + // stop internalLoops + c.evSWnd <- _CLOSE + c.evAck <- _CLOSE + //log.Println("force shutdown") + } +} + +// for sending fin failed +func (c *Conn) fakeShutdown() { + select { + case c.evClose <- _S_FIN0: + default: + } +} + +func (c *Conn) closeR(pk *packet) { + var passive = true + for { + state := atomic.LoadInt32(&c.state) + switch state { + case _S_FIN: + return + case _S_FIN1: // multiple FIN, maybe lost + c.passiveCloseReply(pk, false) + return + case _S_FIN0: // active close preformed + passive = false + } + if !atomic.CompareAndSwapInt32(&c.state, state, _S_FIN1) { + continue + } + c.passiveCloseReply(pk, true) + break + } + // here, R is closed. + // ^^^^^^^^^^^^^^^^^^^^^ + if passive { + // passive closing call closeW contains sending fin and recv ack + // may the ack of fin-2 was lost, then the closeW will timeout + c.closeW() + } + // here, R,W both were closed. + // ^^^^^^^^^^^^^^^^^^^^^ + atomic.StoreInt32(&c.state, _S_FIN) + // stop internalAckLoop + c.evAck <- _CLOSE + + if passive { + // close evRecv within here + c.afterShutdown() + } else { + // notify active close thread + select { + case c.evClose <- _S_FIN: + default: + } + } +} + +func (c *Conn) passiveCloseReply(pk *packet, first bool) { + if pk != nil && pk.flag&_F_FIN != 0 { + if first { + c.checkInQ(pk) + close(c.evRead) + } + // ack the FIN + pk = &packet{seq: _FIN_ACK_SEQ, ack: pk.seq, flag: _F_ACK} + item := nodeOf(pk) + c.internalWrite(item) + } +} + +// check inQ ends orderly, and copy queue data to user space +func (c *Conn) checkInQ(pk *packet) { + if nil != selfSpinWait(func() bool { + return c.inQ.maxCtnSeq+1 == pk.seq + }) { // timeout for waiting inQ to finish + return + } + c.inlock.Lock() + defer c.inlock.Unlock() + if c.inQ.size() > 0 { + for i := c.inQ.head; i != nil; i = i.next { + c.inQReady = append(c.inQReady, i.payload...) + } + } +} diff --git a/weed/udptransfer/stopwatch.go b/weed/udptransfer/stopwatch.go new file mode 100644 index 000000000..2ee4feb57 --- /dev/null +++ b/weed/udptransfer/stopwatch.go @@ -0,0 +1,36 @@ +package udptransfer + +import ( + "fmt" + "os" + "time" +) + +type watch struct { + label string + t1 time.Time +} + +func StartWatch(s string) *watch { + return &watch{ + label: s, + t1: time.Now(), + } +} + +func (w *watch) StopLoops(loop int, size int) { + tu := time.Now().Sub(w.t1).Nanoseconds() + timePerLoop := float64(tu) / float64(loop) + throughput := float64(loop*size) * 1e6 / float64(tu) + tu_ms := float64(tu) / 1e6 + fmt.Fprintf(os.Stderr, "%s tu=%.2f ms tpl=%.0f ns throughput=%.2f K/s\n", w.label, tu_ms, timePerLoop, throughput) +} + +var _kt = float64(1e9 / 1024) + +func (w *watch) Stop(size int) { + tu := time.Now().Sub(w.t1).Nanoseconds() + throughput := float64(size) * _kt / float64(tu) + tu_ms := float64(tu) / 1e6 + fmt.Fprintf(os.Stderr, "%s tu=%.2f ms throughput=%.2f K/s\n", w.label, tu_ms, throughput) +} diff --git a/weed/udptransfer/timer.go b/weed/udptransfer/timer.go new file mode 100644 index 000000000..adcfe50c3 --- /dev/null +++ b/weed/udptransfer/timer.go @@ -0,0 +1,18 @@ +package udptransfer + +import "time" + +const Millisecond = 1e6 + +func Now() int64 { + return time.Now().UnixNano()/Millisecond +} + +func NowNS() int64 { + return time.Now().UnixNano() +} + +func NewTimerChan(d int64) <-chan time.Time { + ticker := time.NewTimer(time.Duration(d) * time.Millisecond) + return ticker.C +} diff --git a/weed/udptransfer/timing_test.go b/weed/udptransfer/timing_test.go new file mode 100644 index 000000000..4f73bc33a --- /dev/null +++ b/weed/udptransfer/timing_test.go @@ -0,0 +1,32 @@ +package udptransfer + +import ( + "math" + "testing" + "time" +) + +func Test_sleep(t *testing.T) { + const loops = 10 + var intervals = [...]int64{1, 1e3, 1e4, 1e5, 1e6, 1e7} + var ret [len(intervals)][loops]int64 + for j := 0; j < len(intervals); j++ { + v := time.Duration(intervals[j]) + for i := 0; i < loops; i++ { + t0 := NowNS() + time.Sleep(v) + ret[j][i] = NowNS() - t0 + } + } + for j := 0; j < len(intervals); j++ { + var exp, sum, stdev float64 + exp = float64(intervals[j]) + for _, v := range ret[j] { + sum += float64(v) + stdev += math.Pow(float64(v)-exp, 2) + } + stdev /= float64(loops) + stdev = math.Sqrt(stdev) + t.Logf("interval=%s sleeping=%s stdev/k=%.2f", time.Duration(intervals[j]), time.Duration(sum/loops), stdev/1e3) + } +} diff --git a/weed/wdclient/volume_udp_client.go b/weed/wdclient/volume_udp_client.go index b8bdcec90..93fd2b227 100644 --- a/weed/wdclient/volume_udp_client.go +++ b/weed/wdclient/volume_udp_client.go @@ -1,21 +1,68 @@ package wdclient import ( + "bufio" + "bytes" + "fmt" "github.com/chrislusf/seaweedfs/weed/glog" "github.com/chrislusf/seaweedfs/weed/pb" + "github.com/chrislusf/seaweedfs/weed/udptransfer" + "github.com/chrislusf/seaweedfs/weed/util" + "github.com/chrislusf/seaweedfs/weed/wdclient/net2" "io" - "pack.ag/tftp" + "net" + "time" ) -// VolumeTcpClient put/get/delete file chunks directly on volume servers without replication +// VolumeUdpClient put/get/delete file chunks directly on volume servers without replication type VolumeUdpClient struct { + cp net2.ConnectionPool +} + +type VolumeUdpConn struct { + net.Conn + bufWriter *bufio.Writer + bufReader *bufio.Reader } func NewVolumeUdpClient() *VolumeUdpClient { + MaxIdleTime := 10 * time.Second return &VolumeUdpClient{ + cp: net2.NewMultiConnectionPool(net2.ConnectionOptions{ + MaxActiveConnections: 16, + MaxIdleConnections: 1, + MaxIdleTime: &MaxIdleTime, + DialMaxConcurrency: 0, + Dial: func(network string, address string) (net.Conn, error) { + + listener, err := udptransfer.NewEndpoint(&udptransfer.Params{ + LocalAddr: "", + Bandwidth: 100, + FastRetransmit: true, + FlatTraffic: true, + IsServ: false, + }) + if err != nil { + return nil, err + } + + conn, err := listener.Dial(address) + if err != nil { + return nil, err + } + return &VolumeUdpConn{ + conn, + bufio.NewWriter(conn), + bufio.NewReader(conn), + }, err + + }, + NowFunc: nil, + ReadTimeout: 0, + WriteTimeout: 0, + }), } } - func (c *VolumeUdpClient) PutFileChunk(volumeServerAddress string, fileId string, fileSize uint32, fileReader io.Reader) (err error) { udpAddress, parseErr := pb.ParseServerAddress(volumeServerAddress, 20001) @@ -23,24 +70,45 @@ func (c *VolumeUdpClient) PutFileChunk(volumeServerAddress string, fileId string return parseErr } - udpClient, _ := tftp.NewClient( - tftp.ClientMode(tftp.ModeOctet), - tftp.ClientBlocksize(9000), - tftp.ClientWindowsize(16), - tftp.ClientTimeout(1), - tftp.ClientTransferSize(true), - tftp.ClientRetransmit(3), - ) + c.cp.Register("udp", udpAddress) + udpConn, getErr := c.cp.Get("udp", udpAddress) + if getErr != nil { + return fmt.Errorf("get connection to %s: %v", udpAddress, getErr) + } + conn := udpConn.RawConn().(*VolumeUdpConn) + defer func() { + if err != nil { + udpConn.DiscardConnection() + } else { + udpConn.ReleaseConnection() + } + }() - fileUrl := "tftp://" + udpAddress + "/" + fileId + buf := []byte("+" + fileId + "\n") + _, err = conn.bufWriter.Write([]byte(buf)) + if err != nil { + return + } + util.Uint32toBytes(buf[0:4], fileSize) + _, err = conn.bufWriter.Write(buf[0:4]) + if err != nil { + return + } + _, err = io.Copy(conn.bufWriter, fileReader) + if err != nil { + return + } + conn.bufWriter.Write([]byte("!\n")) + conn.bufWriter.Flush() - // println("put", fileUrl, "...") - err = udpClient.Put(fileUrl, fileReader, int64(fileSize)) + ret, _, err := conn.bufReader.ReadLine() if err != nil { - glog.Errorf("udp put %s: %v", fileUrl, err) + glog.V(0).Infof("upload by udp: %v", err) return } - // println("sent", fileUrl) + if !bytes.HasPrefix(ret, []byte("+OK")) { + glog.V(0).Infof("upload by udp: %v", string(ret)) + } - return + return nil }