Chris Lu
4 years ago
16 changed files with 2711 additions and 24 deletions
-
2go.mod
-
4go.sum
-
37weed/command/volume.go
-
715weed/udptransfer/conn.go
-
86weed/udptransfer/conn_test.go
-
95weed/udptransfer/debug.go
-
339weed/udptransfer/endpoint.go
-
42weed/udptransfer/endpoint_test.go
-
333weed/udptransfer/linked.go
-
236weed/udptransfer/linked_test.go
-
142weed/udptransfer/packet.go
-
516weed/udptransfer/state.go
-
36weed/udptransfer/stopwatch.go
-
18weed/udptransfer/timer.go
-
32weed/udptransfer/timing_test.go
-
100weed/wdclient/volume_udp_client.go
@ -0,0 +1,715 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"time" |
|||
) |
|||
|
|||
// Protocol tuning limits. The RTT, RTO and ATO bounds are compared against
// c.rtt, c.rto and c.ato, which this file converts with time.Millisecond.
// The SWND bounds clamp the result of calSwnd.
const (
	_MAX_RETRIES = 6 // NOTE(review): not referenced in the visible part of this file
	_MIN_RTT     = 8
	_MIN_RTO     = 30
	_MIN_ATO     = 2
	_MAX_ATO     = 10
	_MIN_SWND    = 10
	_MAX_SWND    = 960
)

// Event values exchanged over the connection's internal channels:
// the _VACK_* levels travel on evAck (see internalAckLoop and makeAck),
// _VSWND_ACTIVE and _VRETR_IMMED travel on evSWnd (see internalSendLoop).
const (
	_VACK_SCHED = iota + 1
	_VACK_QUICK
	_VACK_MUST
	_VSWND_ACTIVE
	_VRETR_IMMED
)

const (
	// _RETR_REST is returned by retransmit when the outgoing queue is empty.
	_RETR_REST = -1
	// _CLOSE is the channel value that tells the internal loops to exit.
	_CLOSE = 0xff
)

// debug is the logging verbosity; it is assigned from Params.Debug by
// set_debug_params and compared against 1, 2 and 3 at the log sites.
var debug int
|||
|
|||
func nodeOf(pk *packet) *qNode { |
|||
return &qNode{packet: pk} |
|||
} |
|||
|
|||
func (c *Conn) internalRecvLoop() { |
|||
defer func() { |
|||
// avoid send to closed channel while some replaying
|
|||
// data packets were received in shutting down.
|
|||
_ = recover() |
|||
}() |
|||
var buf, body []byte |
|||
for { |
|||
select { |
|||
case buf = <-c.evRecv: |
|||
if buf != nil { |
|||
body = buf[_TH_SIZE:] |
|||
} else { // shutdown
|
|||
return |
|||
} |
|||
} |
|||
pk := new(packet) |
|||
// keep the original buffer, so we could recycle it in future
|
|||
pk.buffer = buf |
|||
unmarshall(pk, body) |
|||
if pk.flag&_F_SACK != 0 { |
|||
c.processSAck(pk) |
|||
continue |
|||
} |
|||
if pk.flag&_F_ACK != 0 { |
|||
c.processAck(pk) |
|||
} |
|||
if pk.flag&_F_DATA != 0 { |
|||
c.insertData(pk) |
|||
} else if pk.flag&_F_FIN != 0 { |
|||
if pk.flag&_F_RESET != 0 { |
|||
go c.forceShutdownWithLock() |
|||
} else { |
|||
go c.closeR(pk) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) internalSendLoop() { |
|||
var timer = time.NewTimer(time.Duration(c.rtt) * time.Millisecond) |
|||
for { |
|||
select { |
|||
case v := <-c.evSWnd: |
|||
switch v { |
|||
case _VRETR_IMMED: |
|||
c.outlock.Lock() |
|||
c.retransmit2() |
|||
c.outlock.Unlock() |
|||
case _VSWND_ACTIVE: |
|||
timer.Reset(time.Duration(c.rtt) * time.Millisecond) |
|||
case _CLOSE: |
|||
return |
|||
} |
|||
case <-timer.C: // timeout yet
|
|||
var notifySender bool |
|||
c.outlock.Lock() |
|||
rest, _ := c.retransmit() |
|||
switch rest { |
|||
case _RETR_REST, 0: // nothing to send
|
|||
if c.outQ.size() > 0 { |
|||
timer.Reset(time.Duration(c.rtt) * time.Millisecond) |
|||
} else { |
|||
timer.Stop() |
|||
// avoid sender blocking
|
|||
notifySender = true |
|||
} |
|||
default: // recent rto point
|
|||
timer.Reset(time.Duration(minI64(rest, c.rtt)) * time.Millisecond) |
|||
} |
|||
c.outlock.Unlock() |
|||
if notifySender { |
|||
select { |
|||
case c.evSend <- 1: |
|||
default: |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) internalAckLoop() { |
|||
// var ackTimer = time.NewTicker(time.Duration(c.ato))
|
|||
var ackTimer = time.NewTimer(time.Duration(c.ato) * time.Millisecond) |
|||
var lastAckState byte |
|||
for { |
|||
var v byte |
|||
select { |
|||
case <-ackTimer.C: |
|||
// may cause sending duplicated ack if ato>rtt
|
|||
v = _VACK_QUICK |
|||
case v = <-c.evAck: |
|||
ackTimer.Reset(time.Duration(c.ato) * time.Millisecond) |
|||
state := lastAckState |
|||
lastAckState = v |
|||
if state != v { |
|||
if v == _CLOSE { |
|||
return |
|||
} |
|||
v = _VACK_MUST |
|||
} |
|||
} |
|||
c.inlock.Lock() |
|||
if pkAck := c.makeAck(v); pkAck != nil { |
|||
c.internalWrite(nodeOf(pkAck)) |
|||
} |
|||
c.inlock.Unlock() |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) retransmit() (rest int64, count int32) { |
|||
var now, rto = Now(), c.rto |
|||
var limit = c.cwnd |
|||
for item := c.outQ.head; item != nil && limit > 0; item = item.next { |
|||
if item.scnt != _SENT_OK { // ACKed has scnt==-1
|
|||
diff := now - item.sent |
|||
if diff > rto { // already rto
|
|||
c.internalWrite(item) |
|||
count++ |
|||
} else { |
|||
// continue search next min rto duration
|
|||
if rest > 0 { |
|||
rest = minI64(rest, rto-diff+1) |
|||
} else { |
|||
rest = rto - diff + 1 |
|||
} |
|||
limit-- |
|||
} |
|||
} |
|||
} |
|||
c.outDupCnt += int(count) |
|||
if count > 0 { |
|||
shrcond := (c.fastRetransmit && count > maxI32(c.cwnd>>5, 4)) || (!c.fastRetransmit && count > c.cwnd>>3) |
|||
if shrcond && now-c.lastShrink > c.rto { |
|||
log.Printf("shrink cwnd from=%d to=%d s/4=%d", c.cwnd, c.cwnd>>1, c.swnd>>2) |
|||
c.lastShrink = now |
|||
// shrink cwnd and ensure cwnd >= swnd/4
|
|||
if c.cwnd > c.swnd>>1 { |
|||
c.cwnd >>= 1 |
|||
} |
|||
} |
|||
} |
|||
if c.outQ.size() > 0 { |
|||
return |
|||
} |
|||
return _RETR_REST, 0 |
|||
} |
|||
|
|||
func (c *Conn) retransmit2() (count int32) { |
|||
var limit, now = minI32(c.outPending>>4, 8), Now() |
|||
var fRtt = c.rtt |
|||
if now-c.lastShrink > c.rto { |
|||
fRtt += maxI64(c.rtt>>4, 1) |
|||
} else { |
|||
fRtt += maxI64(c.rtt>>1, 2) |
|||
} |
|||
for item := c.outQ.head; item != nil && count < limit; item = item.next { |
|||
if item.scnt != _SENT_OK { // ACKed has scnt==-1
|
|||
if item.miss >= 3 && now-item.sent >= fRtt { |
|||
item.miss = 0 |
|||
c.internalWrite(item) |
|||
count++ |
|||
} |
|||
} |
|||
} |
|||
c.fRCnt += int(count) |
|||
c.outDupCnt += int(count) |
|||
return |
|||
} |
|||
|
|||
// inputAndSend queues pk for delivery: it waits until the number of
// pending packets drops below cwnd+missed, assigns the next sequence
// number, appends the packet to the outgoing queue and writes it out.
//
// While waiting it blocks on c.evSend; a pending write timeout (c.wtmo) is
// consumed once and bounds that wait. It returns io.EOF when _CLOSE is
// received on c.evSend and ErrIOTimeout when the wait times out.
// When c.flatTraffic is set, every fourth packet may sleep to pace sending.
func (c *Conn) inputAndSend(pk *packet) error {
	item := &qNode{packet: pk}
	// Mark the start of a pacing window (see the flatTraffic branch below).
	if c.mySeq&3 == 1 {
		c.tSlotT0 = NowNS()
	}
	c.outlock.Lock()
	// inflight packets exceeds cwnd
	// inflight includes: 1, unacked; 2, missed
	for c.outPending >= c.cwnd+c.missed {
		// Release the lock while blocked so ack processing can make progress.
		c.outlock.Unlock()
		if c.wtmo > 0 {
			var tmo int64
			// The timeout is one-shot: it is cleared as soon as it is used.
			tmo, c.wtmo = c.wtmo, 0
			select {
			case v := <-c.evSend:
				if v == _CLOSE {
					return io.EOF
				}
			case <-NewTimerChan(tmo):
				return ErrIOTimeout
			}
		} else {
			if v := <-c.evSend; v == _CLOSE {
				return io.EOF
			}
		}
		c.outlock.Lock()
	}
	c.outPending++
	c.outPkCnt++
	c.mySeq++
	pk.seq = c.mySeq
	c.outQ.appendTail(item)
	c.internalWrite(item)
	c.outlock.Unlock()
	// Re-arm the retransmission timer; this send is intentionally blocking.
	c.evSWnd <- _VSWND_ACTIVE
	if c.mySeq&3 == 0 && c.flatTraffic {
		// calculate time error between tslot with actual usage.
		// consider last sleep time error
		t1 := NowNS()
		terr := c.tSlot<<2 - c.lastSErr - (t1 - c.tSlotT0)
		// rest terr/2 if current time usage less than tslot of 100us.
		if terr > 1e5 { // 100us
			time.Sleep(time.Duration(terr >> 1))
			c.lastSErr = maxI64(NowNS()-t1-terr, 0)
		} else {
			c.lastSErr >>= 1
		}
	}
	return nil
}
|||
|
|||
func (c *Conn) internalWrite(item *qNode) { |
|||
if item.scnt >= 20 { |
|||
// no exception of sending fin
|
|||
if item.flag&_F_FIN != 0 { |
|||
c.fakeShutdown() |
|||
c.dest = nil |
|||
return |
|||
} else { |
|||
log.Println("Warn: too many retries", item) |
|||
if c.urgent > 0 { // abort
|
|||
c.forceShutdown() |
|||
return |
|||
} else { // continue to retry 10
|
|||
c.urgent++ |
|||
item.scnt = 10 |
|||
} |
|||
} |
|||
} |
|||
// update current sent time and prev sent time
|
|||
item.sent, item.sent_1 = Now(), item.sent |
|||
item.scnt++ |
|||
buf := item.marshall(c.connID) |
|||
if debug >= 3 { |
|||
var pkType = packetTypeNames[item.flag] |
|||
if item.flag&_F_SACK != 0 { |
|||
log.Printf("send %s trp=%d on=%d %x", pkType, item.seq, item.ack, buf[_AH_SIZE+4:]) |
|||
} else { |
|||
log.Printf("send %s seq=%d ack=%d scnt=%d len=%d", pkType, item.seq, item.ack, item.scnt, len(buf)-_TH_SIZE) |
|||
} |
|||
} |
|||
c.sock.WriteToUDP(buf, c.dest) |
|||
} |
|||
|
|||
func (c *Conn) logAck(ack uint32) { |
|||
c.lastAck = ack |
|||
c.lastAckTime = Now() |
|||
} |
|||
|
|||
func (c *Conn) makeLastAck() (pk *packet) { |
|||
c.inlock.Lock() |
|||
defer c.inlock.Unlock() |
|||
if Now()-c.lastAckTime < c.rtt { |
|||
return nil |
|||
} |
|||
pk = &packet{ |
|||
ack: maxU32(c.lastAck, c.inQ.maxCtnSeq), |
|||
flag: _F_ACK, |
|||
} |
|||
c.logAck(pk.ack) |
|||
return |
|||
} |
|||
|
|||
func (c *Conn) makeAck(level byte) (pk *packet) { |
|||
now := Now() |
|||
if level < _VACK_MUST && now-c.lastAckTime < c.ato { |
|||
if level < _VACK_QUICK || now-c.lastAckTime < minI64(c.ato>>2, 1) { |
|||
return |
|||
} |
|||
} |
|||
// ready Q <-|
|
|||
// |-> outQ start (or more right)
|
|||
// |-> bitmap start
|
|||
// [predecessor] [predecessor+1] [predecessor+2] .....
|
|||
var fakeSAck bool |
|||
var predecessor = c.inQ.maxCtnSeq |
|||
bmap, tbl := c.inQ.makeHolesBitmap(predecessor) |
|||
if len(bmap) <= 0 { // fake sack
|
|||
bmap = make([]uint64, 1) |
|||
bmap[0], tbl = 1, 1 |
|||
fakeSAck = true |
|||
} |
|||
// head 4-byte: TBL:1 | SCNT:1 | DELAY:2
|
|||
buf := make([]byte, len(bmap)*8+4) |
|||
pk = &packet{ |
|||
ack: predecessor + 1, |
|||
flag: _F_SACK, |
|||
payload: buf, |
|||
} |
|||
if fakeSAck { |
|||
pk.ack-- |
|||
} |
|||
buf[0] = byte(tbl) |
|||
// mark delayed time according to the time reference point
|
|||
if trp := c.inQ.lastIns; trp != nil { |
|||
delayed := now - trp.sent |
|||
if delayed < c.rtt { |
|||
pk.seq = trp.seq |
|||
pk.flag |= _F_TIME |
|||
buf[1] = trp.scnt |
|||
if delayed <= 0 { |
|||
delayed = 1 |
|||
} |
|||
binary.BigEndian.PutUint16(buf[2:], uint16(delayed)) |
|||
} |
|||
} |
|||
buf1 := buf[4:] |
|||
for i, b := range bmap { |
|||
binary.BigEndian.PutUint64(buf1[i*8:], b) |
|||
} |
|||
c.logAck(predecessor) |
|||
return |
|||
} |
|||
|
|||
// unmarshallSAck decodes a SACK payload: a 4-byte head
// (TBL:1 | SCNT:1 | DELAY:2, big-endian) followed by the hole bitmap as
// big-endian 64-bit words.
//
// The payload comes from the network, so its length is validated: a payload
// shorter than the 4-byte head yields a nil bmap (callers treat that as a
// bad packet), and only complete 8-byte words after the head are decoded.
// Previously such inputs caused an index-out-of-range panic.
func unmarshallSAck(data []byte) (bmap []uint64, tbl uint32, delayed uint16, scnt uint8) {
	const headSize = 4
	if len(data) < headSize {
		return nil, 0, 0, 0
	}
	tbl = uint32(data[0])
	scnt = data[1]
	delayed = binary.BigEndian.Uint16(data[2:])

	words := data[headSize:]
	// Non-nil even when empty, so a head-only payload is not "bad".
	bmap = make([]uint64, len(words)>>3)
	for i := range bmap {
		bmap[i] = binary.BigEndian.Uint64(words[i*8:])
	}
	return bmap, tbl, delayed, scnt
}
|||
|
|||
func calSwnd(bandwidth, rtt int64) int32 { |
|||
w := int32(bandwidth * rtt / (8000 * _MSS)) |
|||
if w <= _MAX_SWND { |
|||
if w >= _MIN_SWND { |
|||
return w |
|||
} else { |
|||
return _MIN_SWND |
|||
} |
|||
} else { |
|||
return _MAX_SWND |
|||
} |
|||
} |
|||
|
|||
// measure updates the round-trip estimators from a timed SACK.
//
// seq identifies the packet used as the time reference, delayed is how long
// the peer held it before acknowledging, and scnt is the send count the
// peer observed. The sample is dropped when the packet is no longer in the
// outgoing queue, when it was resent more than once since that ack was
// generated, or when the sample looks abnormal. Otherwise srtt, rtt, swnd,
// tSlot, ato, mdev and rto are recomputed.
func (c *Conn) measure(seq uint32, delayed int64, scnt uint8) {
	target := c.outQ.get(seq)
	if target != nil {
		var lastSent int64
		switch target.scnt - scnt {
		case 0:
			// not sent again since this ack was sent out
			lastSent = target.sent
		case 1:
			// sent again once since this ack was sent out
			// then use prev sent time
			lastSent = target.sent_1
		default:
			// can't measure here because the packet was sent too many times
			return
		}
		// real-time rtt
		rtt := Now() - lastSent - delayed
		// reject these abnormal measures:
		// 1. rtt too small -> rtt/8
		// 2. backlogging too long
		if rtt < maxI64(c.rtt>>3, 1) || delayed > c.rtt>>1 {
			return
		}
		// srtt: update 1/8 (srtt is kept scaled by 8; rtt = srtt>>3)
		err := rtt - (c.srtt >> 3)
		c.srtt += err
		c.rtt = c.srtt >> 3
		if c.rtt < _MIN_RTT {
			c.rtt = _MIN_RTT
		}
		// swnd: new = (7*old + calSwnd)/8, i.e. a 1/8 update.
		// NOTE(review): the original comment here said "update 1/4".
		swnd := c.swnd<<3 - c.swnd + calSwnd(c.bandwidth, c.rtt)
		c.swnd = swnd >> 3
		// presumably nanoseconds per packet (rtt is used as milliseconds
		// elsewhere and tSlot is compared with NowNS differences) — confirm
		c.tSlot = c.rtt * 1e6 / int64(c.swnd)
		// ato = rtt/16, clamped to [_MIN_ATO, _MAX_ATO]
		c.ato = c.rtt >> 4
		if c.ato < _MIN_ATO {
			c.ato = _MIN_ATO
		} else if c.ato > _MAX_ATO {
			c.ato = _MAX_ATO
		}
		// Negative errors (rtt decreased) are damped by a further /8 when
		// they still exceed mdev/4.
		if err < 0 {
			err = -err
			err -= c.mdev >> 2
			if err > 0 {
				err >>= 3
			}
		} else {
			err -= c.mdev >> 2
		}
		// mdev: update 1/4
		c.mdev += err
		// rto grows immediately but shrinks only halfway per sample.
		rto := c.rtt + maxI64(c.rtt<<1, c.mdev)
		if rto >= c.rto {
			c.rto = rto
		} else {
			c.rto = (c.rto + rto) >> 1
		}
		if c.rto < _MIN_RTO {
			c.rto = _MIN_RTO
		}
		if debug >= 1 {
			log.Printf("--- rtt=%d srtt=%d rto=%d swnd=%d", c.rtt, c.srtt, c.rto, c.swnd)
		}
	}
}
|||
|
|||
func (c *Conn) processSAck(pk *packet) { |
|||
c.outlock.Lock() |
|||
bmap, tbl, delayed, scnt := unmarshallSAck(pk.payload) |
|||
if bmap == nil { // bad packet
|
|||
c.outlock.Unlock() |
|||
return |
|||
} |
|||
if pk.flag&_F_TIME != 0 { |
|||
c.measure(pk.seq, int64(delayed), scnt) |
|||
} |
|||
deleted, missed, continuous := c.outQ.deleteByBitmap(bmap, pk.ack, tbl) |
|||
if deleted > 0 { |
|||
c.ackHit(deleted, missed) |
|||
// lock is released
|
|||
} else { |
|||
c.outlock.Unlock() |
|||
} |
|||
if c.fastRetransmit && !continuous { |
|||
// peer Q is uncontinuous, then trigger FR
|
|||
if deleted == 0 { |
|||
c.evSWnd <- _VRETR_IMMED |
|||
} else { |
|||
select { |
|||
case c.evSWnd <- _VRETR_IMMED: |
|||
default: |
|||
} |
|||
} |
|||
} |
|||
if debug >= 2 { |
|||
log.Printf("SACK qhead=%d deleted=%d outPending=%d on=%d %016x", |
|||
c.outQ.distanceOfHead(0), deleted, c.outPending, pk.ack, bmap) |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) processAck(pk *packet) { |
|||
c.outlock.Lock() |
|||
if end := c.outQ.get(pk.ack); end != nil { // ack hit
|
|||
_, deleted := c.outQ.deleteBefore(end) |
|||
c.ackHit(deleted, 0) // lock is released
|
|||
if debug >= 2 { |
|||
log.Printf("ACK hit on=%d", pk.ack) |
|||
} |
|||
// special case: ack the FIN
|
|||
if pk.seq == _FIN_ACK_SEQ { |
|||
select { |
|||
case c.evClose <- _S_FIN0: |
|||
default: |
|||
} |
|||
} |
|||
} else { // duplicated ack
|
|||
if debug >= 2 { |
|||
log.Printf("ACK miss on=%d", pk.ack) |
|||
} |
|||
if pk.flag&_F_SYN != 0 { // No.3 Ack lost
|
|||
if pkAck := c.makeLastAck(); pkAck != nil { |
|||
c.internalWrite(nodeOf(pkAck)) |
|||
} |
|||
} |
|||
c.outlock.Unlock() |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) ackHit(deleted, missed int32) { |
|||
// must in outlock
|
|||
c.outPending -= deleted |
|||
now := Now() |
|||
if c.cwnd < c.swnd && now-c.lastShrink > c.rto { |
|||
if c.cwnd < c.swnd>>1 { |
|||
c.cwnd <<= 1 |
|||
} else { |
|||
c.cwnd += deleted << 1 |
|||
} |
|||
} |
|||
if c.cwnd > c.swnd { |
|||
c.cwnd = c.swnd |
|||
} |
|||
if now-c.lastRstMis > c.ato { |
|||
c.lastRstMis = now |
|||
c.missed = missed |
|||
} else { |
|||
c.missed = c.missed>>1 + missed |
|||
} |
|||
if qswnd := c.swnd >> 4; c.missed > qswnd { |
|||
c.missed = qswnd |
|||
} |
|||
c.outlock.Unlock() |
|||
select { |
|||
case c.evSend <- 1: |
|||
default: |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) insertData(pk *packet) { |
|||
c.inlock.Lock() |
|||
defer c.inlock.Unlock() |
|||
exists := c.inQ.contains(pk.seq) |
|||
// duplicated with already queued or history
|
|||
// means: last ACK were lost
|
|||
if exists || pk.seq <= c.inQ.maxCtnSeq { |
|||
// then send ACK for dups
|
|||
select { |
|||
case c.evAck <- _VACK_MUST: |
|||
default: |
|||
} |
|||
if debug >= 2 { |
|||
dumpQ(fmt.Sprint("duplicated ", pk.seq), c.inQ) |
|||
} |
|||
c.inDupCnt++ |
|||
return |
|||
} |
|||
// record current time in sent and regard as received time
|
|||
item := &qNode{packet: pk, sent: Now()} |
|||
dis := c.inQ.searchInsert(item, c.lastReadSeq) |
|||
if debug >= 3 { |
|||
log.Printf("\t\t\trecv DATA seq=%d dis=%d maxCtn=%d lastReadSeq=%d", item.seq, dis, c.inQ.maxCtnSeq, c.lastReadSeq) |
|||
} |
|||
|
|||
var ackState byte = _VACK_MUST |
|||
var available bool |
|||
switch dis { |
|||
case 0: // impossible
|
|||
return |
|||
case 1: |
|||
if c.inQDirty { |
|||
available = c.inQ.updateContinuous(item) |
|||
if c.inQ.isWholeContinuous() { // whole Q is ordered
|
|||
c.inQDirty = false |
|||
} else { //those holes still exists.
|
|||
ackState = _VACK_QUICK |
|||
} |
|||
} else { |
|||
// here is an ideal situation
|
|||
c.inQ.maxCtnSeq = pk.seq |
|||
available = true |
|||
ackState = _VACK_SCHED |
|||
} |
|||
|
|||
default: // there is an unordered packet, hole occurred here.
|
|||
if !c.inQDirty { |
|||
c.inQDirty = true |
|||
} |
|||
} |
|||
|
|||
// write valid received count
|
|||
c.inPkCnt++ |
|||
c.inQ.lastIns = item |
|||
// try notify ack
|
|||
select { |
|||
case c.evAck <- ackState: |
|||
default: |
|||
} |
|||
if available { // try notify reader
|
|||
select { |
|||
case c.evRead <- 1: |
|||
default: |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) readInQ() bool { |
|||
c.inlock.Lock() |
|||
defer c.inlock.Unlock() |
|||
// read already <-|-> expected Q
|
|||
// [lastReadSeq] | [lastReadSeq+1] [lastReadSeq+2] ......
|
|||
if c.inQ.isEqualsHead(c.lastReadSeq+1) && c.lastReadSeq < c.inQ.maxCtnSeq { |
|||
c.lastReadSeq = c.inQ.maxCtnSeq |
|||
availabled := c.inQ.get(c.inQ.maxCtnSeq) |
|||
availabled, _ = c.inQ.deleteBefore(availabled) |
|||
for i := availabled; i != nil; i = i.next { |
|||
c.inQReady = append(c.inQReady, i.payload...) |
|||
// data was copied, then could recycle memory
|
|||
bpool.Put(i.buffer) |
|||
i.payload = nil |
|||
i.buffer = nil |
|||
} |
|||
return true |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// should not call this function concurrently.
|
|||
func (c *Conn) Read(buf []byte) (nr int, err error) { |
|||
for { |
|||
if len(c.inQReady) > 0 { |
|||
n := copy(buf, c.inQReady) |
|||
c.inQReady = c.inQReady[n:] |
|||
return n, nil |
|||
} |
|||
if !c.readInQ() { |
|||
if c.rtmo > 0 { |
|||
var tmo int64 |
|||
tmo, c.rtmo = c.rtmo, 0 |
|||
select { |
|||
case _, y := <-c.evRead: |
|||
if !y && len(c.inQReady) == 0 { |
|||
return 0, io.EOF |
|||
} |
|||
case <-NewTimerChan(tmo): |
|||
return 0, ErrIOTimeout |
|||
} |
|||
} else { |
|||
// only when evRead is closed and inQReady is empty
|
|||
// then could reply eof
|
|||
if _, y := <-c.evRead; !y && len(c.inQReady) == 0 { |
|||
return 0, io.EOF |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// should not call this function concurrently.
|
|||
func (c *Conn) Write(data []byte) (nr int, err error) { |
|||
for len(data) > 0 && err == nil { |
|||
//buf := make([]byte, _MSS+_AH_SIZE)
|
|||
buf := bpool.Get(c.mss + _AH_SIZE) |
|||
body := buf[_TH_SIZE+_CH_SIZE:] |
|||
n := copy(body, data) |
|||
nr += n |
|||
data = data[n:] |
|||
pk := &packet{flag: _F_DATA, payload: body[:n], buffer: buf[:_AH_SIZE+n]} |
|||
err = c.inputAndSend(pk) |
|||
} |
|||
return |
|||
} |
|||
|
|||
func (c *Conn) LocalAddr() net.Addr { |
|||
return c.sock.LocalAddr() |
|||
} |
|||
|
|||
func (c *Conn) RemoteAddr() net.Addr { |
|||
return c.dest |
|||
} |
|||
|
|||
func (c *Conn) SetDeadline(t time.Time) error { |
|||
c.SetReadDeadline(t) |
|||
c.SetWriteDeadline(t) |
|||
return nil |
|||
} |
|||
|
|||
func (c *Conn) SetReadDeadline(t time.Time) error { |
|||
if d := t.UnixNano()/Millisecond - Now(); d > 0 { |
|||
c.rtmo = d |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (c *Conn) SetWriteDeadline(t time.Time) error { |
|||
if d := t.UnixNano()/Millisecond - Now(); d > 0 { |
|||
c.wtmo = d |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,86 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"math/rand" |
|||
"sort" |
|||
"testing" |
|||
) |
|||
|
|||
var conn *Conn |
|||
|
|||
func init() { |
|||
conn = &Conn{ |
|||
outQ: newLinkedMap(_QModeOut), |
|||
inQ: newLinkedMap(_QModeIn), |
|||
} |
|||
} |
|||
|
|||
// assert reports a formatted test error and panics with "last error" when
// cond is false; it does nothing when cond holds.
func assert(cond bool, t testing.TB, format string, args ...interface{}) {
	if cond {
		return
	}
	t.Errorf(format, args...)
	panic("last error")
}
|||
|
|||
func Test_ordered_insert(t *testing.T) { |
|||
data := []byte{1} |
|||
var pk *packet |
|||
for i := int32(1); i < 33; i++ { |
|||
pk = &packet{ |
|||
seq: uint32(i), |
|||
payload: data, |
|||
} |
|||
conn.insertData(pk) |
|||
assert(conn.inQ.size() == i, t, "len inQ=%d", conn.inQ.size()) |
|||
assert(conn.inQ.maxCtnSeq == pk.seq, t, "lastCtnIn") |
|||
assert(!conn.inQDirty, t, "dirty") |
|||
} |
|||
} |
|||
|
|||
// Test_unordered_insert resets the receive queue, inserts 0xffe packets
// with distinct random non-zero sequence numbers, checks the queue size
// after every insert, and finally compares the queue's highest contiguous
// sequence with the value derived from the sorted list of inserted numbers.
func Test_unordered_insert(t *testing.T) {
	conn.inQ.reset()

	data := []byte{1}
	var pk *packet
	// seqs[0] is never assigned and stays 0.
	var seqs = make([]int, 0xfff)
	// unordered insert, and assert size
	for i := 1; i < len(seqs); i++ {
		var seq uint32
		// Draw until the value is non-zero and not yet in the queue.
		for conn.inQ.contains(seq) || seq == 0 {
			seq = uint32(rand.Int31n(0xFFffff))
		}
		seqs[i] = int(seq)
		pk = &packet{
			seq:     seq,
			payload: data,
		}
		conn.insertData(pk)
		assert(conn.inQ.size() == int32(i), t, "i=%d inQ.len=%d", i, conn.inQ.size())
	}
	// assert lastCtnSeq
	sort.Ints(seqs)
	var zero = 0
	// last ends up pointing at the expected highest contiguous sequence:
	// the end of the run of consecutive values at the start of the sorted
	// list, or zero when that run does not extend past seqs[0].
	var last *int
	for i := 0; i < len(seqs); i++ {
		if i == 0 && seqs[0] != 0 {
			last = &zero
			break
		}
		if last != nil && seqs[i]-*last > 1 {
			if i == 1 {
				last = &zero
			}
			break
		}
		last = &seqs[i]
	}
	// On mismatch, log the first ten sorted sequences to help diagnosis.
	if *last != int(conn.inQ.maxCtnSeq) {
		for i, j := range seqs {
			if i < 10 {
				t.Logf("seq %d", j)
			}
		}
	}
	assert(*last == int(conn.inQ.maxCtnSeq), t, "lastCtnSeq=%d but expected=%d", conn.inQ.maxCtnSeq, *last)
	t.Logf("lastCtnSeq=%d dirty=%v", conn.inQ.maxCtnSeq, conn.inQDirty)
}
@ -0,0 +1,95 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"encoding/hex" |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
"os/signal" |
|||
"runtime" |
|||
"runtime/pprof" |
|||
"sync/atomic" |
|||
"syscall" |
|||
) |
|||
|
|||
// enable_pprof mirrors Params.EnablePprof; when set, set_debug_params
// starts a CPU profile and PrintState stops it.
var enable_pprof bool

// enable_stacktrace mirrors Params.Stacktrace; when set, PrintState dumps
// the stacks of all goroutines.
var enable_stacktrace bool

// debug_inited is flipped from 0 to 1 atomically so that the debug
// parameters are applied only once.
var debug_inited int32
|||
|
|||
func set_debug_params(p *Params) { |
|||
if atomic.CompareAndSwapInt32(&debug_inited, 0, 1) { |
|||
debug = p.Debug |
|||
enable_pprof = p.EnablePprof |
|||
enable_stacktrace = p.Stacktrace |
|||
if enable_pprof { |
|||
f, err := os.Create("udptransfer.pprof") |
|||
if err != nil { |
|||
log.Fatalln(err) |
|||
} |
|||
pprof.StartCPUProfile(f) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) PrintState() { |
|||
log.Printf("inQ=%d inQReady=%d outQ=%d", c.inQ.size(), len(c.inQReady), c.outQ.size()) |
|||
log.Printf("inMaxCtnSeq=%d lastAck=%d lastReadSeq=%d", c.inQ.maxCtnSeq, c.lastAck, c.lastReadSeq) |
|||
if c.inPkCnt > 0 { |
|||
log.Printf("Rx pcnt=%d dups=%d %%d=%f%%", c.inPkCnt, c.inDupCnt, 100*float32(c.inDupCnt)/float32(c.inPkCnt)) |
|||
} |
|||
if c.outPkCnt > 0 { |
|||
log.Printf("Tx pcnt=%d dups=%d %%d=%f%%", c.outPkCnt, c.outDupCnt, 100*float32(c.outDupCnt)/float32(c.outPkCnt)) |
|||
} |
|||
log.Printf("current-rtt=%d FastRetransmit=%d", c.rtt, c.fRCnt) |
|||
if enable_stacktrace { |
|||
var buf = make([]byte, 6400) |
|||
for i := 0; i < 3; i++ { |
|||
n := runtime.Stack(buf, true) |
|||
if n >= len(buf) { |
|||
buf = make([]byte, len(buf)<<1) |
|||
} else { |
|||
buf = buf[:n] |
|||
break |
|||
} |
|||
} |
|||
fmt.Println(string(buf)) |
|||
} |
|||
if enable_pprof { |
|||
pprof.StopCPUProfile() |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) internal_state() { |
|||
ev := make(chan os.Signal, 10) |
|||
signal.Notify(ev, syscall.Signal(12), syscall.SIGINT) |
|||
for v := range ev { |
|||
c.PrintState() |
|||
if v == syscall.SIGINT { |
|||
os.Exit(1) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// printBits prints b as a 64-digit binary number followed by the j, seq
// and dis values on one line of standard output.
func printBits(b uint64, j, s, d uint32) {
	const layout = "bits=%064b j=%d seq=%d dis=%d\n"
	fmt.Printf(layout, b, j, s, d)
}
|||
|
|||
// dumpb logs label followed by a hex dump of buf.
func dumpb(label string, buf []byte) {
	dump := hex.Dump(buf)
	log.Println(label, "\n", dump)
}
|||
|
|||
func dumpQ(s string, q *linkedMap) { |
|||
var seqs = make([]uint32, 0, 20) |
|||
n := q.head |
|||
for i, m := int32(0), q.size(); i < m && n != nil; i++ { |
|||
seqs = append(seqs, n.seq) |
|||
n = n.next |
|||
if len(seqs) == 20 { |
|||
log.Printf("%s: Q=%d", s, seqs) |
|||
seqs = seqs[:0] |
|||
} |
|||
} |
|||
if len(seqs) > 0 { |
|||
log.Printf("%s: Q=%d", s, seqs) |
|||
} |
|||
} |
@ -0,0 +1,339 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
"io" |
|||
"log" |
|||
"math/rand" |
|||
"net" |
|||
"sort" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
"github.com/cloudflare/golibs/bytepool" |
|||
) |
|||
|
|||
const (
	// _SO_BUF_SIZE (8 MiB) is the read buffer size requested for the
	// endpoint's UDP socket in NewEndpoint.
	_SO_BUF_SIZE = 8 << 20
)

var (
	// bpool is the shared pool of packet buffers; it is initialised in
	// init and drained while the endpoint is idle.
	bpool bytepool.BytePool
)
|||
|
|||
// Params configures an Endpoint created with NewEndpoint.
type Params struct {
	LocalAddr      string // address passed to net.ListenPacket("udp", ...)
	Bandwidth      int64  // must be in (0,100]; NewEndpoint shifts it left by 20 (commented there as "mbps to bps")
	Mtu            int    // NOTE(review): not used in the visible code — confirm consumer
	IsServ         bool   // true for a server endpoint that accepts new connections
	FastRetransmit bool   // presumably copied to Conn.fastRetransmit — confirm in NewConn
	FlatTraffic    bool   // presumably copied to Conn.flatTraffic — confirm in NewConn
	EnablePprof    bool   // write a CPU profile (see set_debug_params)
	Stacktrace     bool   // dump goroutine stacks in PrintState
	Debug          int    // log verbosity, stored in the package-level debug variable
}
|||
|
|||
// connID identifies a connection by a pair of ids. In received packets the
// two ids are read from one big-endian 64-bit value: lid from the high 32
// bits and rid from the low 32 bits (see getConnID).
type connID struct {
	lid uint32 // local id; key into Endpoint.lRegistry
	rid uint32 // remote id; tracked per remote address in Endpoint.rRegistry
}
|||
|
|||
// Endpoint owns one UDP socket and multiplexes connections over it.
type Endpoint struct {
	udpconn    *net.UDPConn        // socket opened in NewEndpoint
	state      int32               // accessed atomically; _S_EST0 (server), _S_EST1 (client) or _S_FIN (closed)
	idSeq      uint32              // incremented to allocate local connection ids
	isServ     bool                // copy of Params.IsServ
	listenChan chan *Conn          // accepted connections, consumed by Accept/Listen
	lRegistry  map[uint32]*Conn    // local id -> connection
	rRegistry  map[string][]uint32 // remote address -> sorted remote ids, used to filter duplicated SYNs
	mlock      sync.RWMutex        // guards idSeq and the two registries
	timeout    *time.Timer         // reused by dispatch to bound the hand-off of a packet
	params     Params              // copy of the creation parameters, with Bandwidth scaled
}
|||
|
|||
func (c *connID) setRid(b []byte) { |
|||
c.rid = binary.BigEndian.Uint32(b[_MAGIC_SIZE+6:]) |
|||
} |
|||
|
|||
func init() { |
|||
bpool.Init(0, 2000) |
|||
rand.Seed(NowNS()) |
|||
} |
|||
|
|||
func NewEndpoint(p *Params) (*Endpoint, error) { |
|||
set_debug_params(p) |
|||
if p.Bandwidth <= 0 || p.Bandwidth > 100 { |
|||
return nil, fmt.Errorf("bw->(0,100]") |
|||
} |
|||
conn, err := net.ListenPacket("udp", p.LocalAddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
e := &Endpoint{ |
|||
udpconn: conn.(*net.UDPConn), |
|||
idSeq: 1, |
|||
isServ: p.IsServ, |
|||
listenChan: make(chan *Conn, 1), |
|||
lRegistry: make(map[uint32]*Conn), |
|||
rRegistry: make(map[string][]uint32), |
|||
timeout: time.NewTimer(0), |
|||
params: *p, |
|||
} |
|||
if e.isServ { |
|||
e.state = _S_EST0 |
|||
} else { // client
|
|||
e.state = _S_EST1 |
|||
e.idSeq = uint32(rand.Int31()) |
|||
} |
|||
e.params.Bandwidth = p.Bandwidth << 20 // mbps to bps
|
|||
e.udpconn.SetReadBuffer(_SO_BUF_SIZE) |
|||
go e.internal_listen() |
|||
return e, nil |
|||
} |
|||
|
|||
func (e *Endpoint) internal_listen() { |
|||
const rtmo = time.Duration(30*time.Second) |
|||
var id connID |
|||
for { |
|||
//var buf = make([]byte, 1600)
|
|||
var buf = bpool.Get(1600) |
|||
e.udpconn.SetReadDeadline(time.Now().Add(rtmo)) |
|||
n, addr, err := e.udpconn.ReadFromUDP(buf) |
|||
if err == nil && n >= _AH_SIZE { |
|||
buf = buf[:n] |
|||
e.getConnID(&id, buf) |
|||
|
|||
switch id.lid { |
|||
case 0: // new connection
|
|||
if e.isServ { |
|||
go e.acceptNewConn(id, addr, buf) |
|||
} else { |
|||
dumpb("drop", buf) |
|||
} |
|||
|
|||
case _INVALID_SEQ: |
|||
dumpb("drop invalid", buf) |
|||
|
|||
default: // old connection
|
|||
e.mlock.RLock() |
|||
conn := e.lRegistry[id.lid] |
|||
e.mlock.RUnlock() |
|||
if conn != nil { |
|||
e.dispatch(conn, buf) |
|||
} else { |
|||
e.resetPeer(addr, id) |
|||
dumpb("drop null", buf) |
|||
} |
|||
} |
|||
|
|||
} else if err != nil { |
|||
// idle process
|
|||
if nerr, y := err.(net.Error); y && nerr.Timeout() { |
|||
e.idleProcess() |
|||
continue |
|||
} |
|||
// other errors
|
|||
if atomic.LoadInt32(&e.state) == _S_FIN { |
|||
return |
|||
} else { |
|||
log.Println("Error: read sock", err) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) idleProcess() { |
|||
// recycle/shrink memory
|
|||
bpool.Drain() |
|||
e.mlock.Lock() |
|||
defer e.mlock.Unlock() |
|||
// reset urgent
|
|||
for _, c := range e.lRegistry { |
|||
c.outlock.Lock() |
|||
if c.outQ.size() == 0 && c.urgent != 0 { |
|||
c.urgent = 0 |
|||
} |
|||
c.outlock.Unlock() |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) Dial(addr string) (*Conn, error) { |
|||
rAddr, err := net.ResolveUDPAddr("udp", addr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
e.mlock.Lock() |
|||
e.idSeq++ |
|||
id := connID{e.idSeq, 0} |
|||
conn := NewConn(e, rAddr, id) |
|||
e.lRegistry[id.lid] = conn |
|||
e.mlock.Unlock() |
|||
if atomic.LoadInt32(&e.state) != _S_FIN { |
|||
err = conn.initConnection(nil) |
|||
return conn, err |
|||
} |
|||
return nil, io.EOF |
|||
} |
|||
|
|||
func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) { |
|||
rKey := addr.String() |
|||
e.mlock.Lock() |
|||
// map: remoteAddr => remoteConnID
|
|||
// filter duplicated syn packets
|
|||
if newArr := insertRid(e.rRegistry[rKey], id.rid); newArr != nil { |
|||
e.rRegistry[rKey] = newArr |
|||
} else { |
|||
e.mlock.Unlock() |
|||
log.Println("Warn: duplicated connection", addr) |
|||
return |
|||
} |
|||
e.idSeq++ |
|||
id.lid = e.idSeq |
|||
conn := NewConn(e, addr, id) |
|||
e.lRegistry[id.lid] = conn |
|||
e.mlock.Unlock() |
|||
err := conn.initConnection(buf) |
|||
if err == nil { |
|||
select { |
|||
case e.listenChan <- conn: |
|||
case <-time.After(_10ms): |
|||
log.Println("Warn: no listener") |
|||
} |
|||
} else { |
|||
e.removeConn(id, addr) |
|||
log.Println("Error: init_connection", addr, err) |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) { |
|||
e.mlock.Lock() |
|||
delete(e.lRegistry, id.lid) |
|||
rKey := addr.String() |
|||
if newArr := deleteRid(e.rRegistry[rKey], id.rid); newArr != nil { |
|||
if len(newArr) > 0 { |
|||
e.rRegistry[rKey] = newArr |
|||
} else { |
|||
delete(e.rRegistry, rKey) |
|||
} |
|||
} |
|||
e.mlock.Unlock() |
|||
} |
|||
|
|||
// net.Listener
|
|||
func (e *Endpoint) Close() error { |
|||
state := atomic.LoadInt32(&e.state) |
|||
if state > 0 && atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) { |
|||
err := e.udpconn.Close() |
|||
e.lRegistry = nil |
|||
e.rRegistry = nil |
|||
select { // release listeners
|
|||
case e.listenChan <- nil: |
|||
default: |
|||
} |
|||
return err |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// net.Listener
|
|||
func (e *Endpoint) Addr() net.Addr { |
|||
return e.udpconn.LocalAddr() |
|||
} |
|||
|
|||
// net.Listener
|
|||
func (e *Endpoint) Accept() (net.Conn, error) { |
|||
if atomic.LoadInt32(&e.state) == _S_EST0 { |
|||
return <-e.listenChan, nil |
|||
} else { |
|||
return nil, io.EOF |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) Listen() *Conn { |
|||
if atomic.LoadInt32(&e.state) == _S_EST0 { |
|||
return <-e.listenChan |
|||
} else { |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// tmo in MS
|
|||
func (e *Endpoint) ListenTimeout(tmo int64) *Conn { |
|||
if tmo <= 0 { |
|||
return e.Listen() |
|||
} |
|||
if atomic.LoadInt32(&e.state) == _S_EST0 { |
|||
select { |
|||
case c := <-e.listenChan: |
|||
return c |
|||
case <-NewTimerChan(tmo): |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (e *Endpoint) getConnID(idPtr *connID, buf []byte) { |
|||
// TODO determine magic header
|
|||
magicAndLen := binary.BigEndian.Uint64(buf) |
|||
if int(magicAndLen&0xFFff) == len(buf) { |
|||
id := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:]) |
|||
idPtr.lid = uint32(id >> 32) |
|||
idPtr.rid = uint32(id) |
|||
} else { |
|||
idPtr.lid = _INVALID_SEQ |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) dispatch(c *Conn, buf []byte) { |
|||
e.timeout.Reset(30*time.Millisecond) |
|||
select { |
|||
case c.evRecv <- buf: |
|||
case <-e.timeout.C: |
|||
log.Println("Warn: dispatch packet failed") |
|||
} |
|||
} |
|||
|
|||
func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) { |
|||
pk := &packet{flag: _F_FIN | _F_RESET} |
|||
buf := nodeOf(pk).marshall(id) |
|||
e.udpconn.WriteToUDP(buf, addr) |
|||
} |
|||
|
|||
// u32Slice adapts []uint32 to sort.Interface, ordering ascending.
type u32Slice []uint32

// Len returns the number of elements.
func (p u32Slice) Len() int {
	return len(p)
}

// Less reports whether element i is smaller than element j.
func (p u32Slice) Less(i, j int) bool {
	return p[i] < p[j]
}

// Swap exchanges elements i and j.
func (p u32Slice) Swap(i, j int) {
	p[j], p[i] = p[i], p[j]
}
|||
|
|||
// insertRid inserts rid into the ascending-sorted array and returns the
// resulting slice, or nil when rid is already present. The returned slice
// may share its backing array with the argument.
//
// The binary search that detects duplicates already yields the insertion
// point, so the element is placed there directly instead of being appended
// and the whole slice sorted again. array must be sorted ascending, which
// holds for slices produced by insertRid and deleteRid.
func insertRid(array []uint32, rid uint32) []uint32 {
	pos := sort.Search(len(array), func(n int) bool {
		return array[n] >= rid
	})
	if pos < len(array) && array[pos] == rid {
		return nil
	}
	array = append(array, 0)
	copy(array[pos+1:], array[pos:])
	array[pos] = rid
	return array
}
|||
|
|||
// deleteRid returns a newly allocated copy of the ascending-sorted array
// without rid, or nil when rid is not present. Removing the only element
// yields an empty, non-nil slice. The argument is not modified.
func deleteRid(array []uint32, rid uint32) []uint32 {
	idx := sort.Search(len(array), func(i int) bool {
		return array[i] >= rid
	})
	if idx == len(array) || array[idx] != rid {
		return nil
	}
	rest := make([]uint32, 0, len(array)-1)
	rest = append(rest, array[:idx]...)
	rest = append(rest, array[idx+1:]...)
	return rest
}
@ -0,0 +1,42 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"fmt" |
|||
"math/rand" |
|||
"sort" |
|||
"testing" |
|||
) |
|||
|
|||
func Test_insert_delete_rid(t *testing.T) { |
|||
var a []uint32 |
|||
var b = make([]uint32, 0, 1e3) |
|||
var uniq = make(map[uint32]int) |
|||
// insert into a with random
|
|||
for i := 0; i < cap(b); i++ { |
|||
n := uint32(rand.Int31()) |
|||
if _, y := uniq[n]; !y { |
|||
b = append(b, n) |
|||
uniq[n] = 1 |
|||
} |
|||
dups := 1 |
|||
if i&0xf == 0xf { |
|||
dups = 3 |
|||
} |
|||
for j := 0; j < dups; j++ { |
|||
if aa := insertRid(a, n); aa != nil { |
|||
a = aa |
|||
} |
|||
} |
|||
} |
|||
sort.Sort(u32Slice(b)) |
|||
bStr := fmt.Sprintf("%d", b) |
|||
aStr := fmt.Sprintf("%d", a) |
|||
assert(aStr == bStr, t, "a!=b") |
|||
|
|||
for i := 0; i < len(b); i++ { |
|||
if aa := deleteRid(a, b[i]); aa != nil { |
|||
a = aa |
|||
} |
|||
} |
|||
assert(len(a) == 0, t, "a!=0") |
|||
} |
@ -0,0 +1,333 @@ |
|||
package udptransfer |
|||
|
|||
// qNode is one element of a linkedMap: an embedded packet plus its list links
// and per-packet send bookkeeping.
type qNode struct {
	*packet
	prev   *qNode
	next   *qNode
	sent   int64 // last sent time
	sent_1 int64 // prev sent time
	miss   int   // sack miss count
}
|||
|
|||
// linkedMap is a doubly linked list of qNodes that is additionally indexed by
// packet sequence number through qmap.
type linkedMap struct {
	head      *qNode            // first node, nil when the list is empty
	tail      *qNode            // last node, nil when the list is empty
	qmap      map[uint32]*qNode // seq -> node
	lastIns   *qNode
	maxCtnSeq uint32 // last seq of the continuous run, advanced by updateContinuous
	mode      int    // _QModeIn or _QModeOut
}
|||
|
|||
// Queue modes passed to newLinkedMap. deleteBefore returns packet buffers to
// bpool only for _QModeOut queues.
const (
	_QModeIn  = 1
	_QModeOut = 2
)
|||
|
|||
func newLinkedMap(qmode int) *linkedMap { |
|||
return &linkedMap{ |
|||
qmap: make(map[uint32]*qNode), |
|||
mode: qmode, |
|||
} |
|||
} |
|||
|
|||
func (l *linkedMap) get(seq uint32) (i *qNode) { |
|||
i = l.qmap[seq] |
|||
return |
|||
} |
|||
|
|||
func (l *linkedMap) contains(seq uint32) bool { |
|||
_, y := l.qmap[seq] |
|||
return y |
|||
} |
|||
|
|||
func (l *linkedMap) size() int32 { |
|||
return int32(len(l.qmap)) |
|||
} |
|||
|
|||
func (l *linkedMap) reset() { |
|||
l.head = nil |
|||
l.tail = nil |
|||
l.lastIns = nil |
|||
l.maxCtnSeq = 0 |
|||
l.qmap = make(map[uint32]*qNode) |
|||
} |
|||
|
|||
func (l *linkedMap) isEqualsHead(seq uint32) bool { |
|||
return l.head != nil && seq == l.head.seq |
|||
} |
|||
|
|||
func (l *linkedMap) distanceOfHead(seq uint32) int32 { |
|||
if l.head != nil { |
|||
return int32(seq - l.head.seq) |
|||
} else { |
|||
return -1 |
|||
} |
|||
} |
|||
|
|||
func (l *linkedMap) appendTail(one *qNode) { |
|||
if l.tail != nil { |
|||
l.tail.next = one |
|||
one.prev = l.tail |
|||
l.tail = one |
|||
} else { |
|||
l.head = one |
|||
l.tail = one |
|||
} |
|||
l.qmap[one.seq] = one |
|||
} |
|||
|
|||
// xxx - n - yyy
|
|||
// xxx - yyy
|
|||
func (l *linkedMap) deleteAt(n *qNode) { |
|||
x, y := n.prev, n.next |
|||
if x != nil { |
|||
x.next = y |
|||
} else { |
|||
l.head = y |
|||
} |
|||
if y != nil { |
|||
y.prev = x |
|||
} else { |
|||
l.tail = x |
|||
} |
|||
n.prev, n.next = nil, nil |
|||
delete(l.qmap, n.seq) |
|||
} |
|||
|
|||
// delete with n <- ...n |
|
|||
func (l *linkedMap) deleteBefore(n *qNode) (left *qNode, deleted int32) { |
|||
for i := n; i != nil; i = i.prev { |
|||
delete(l.qmap, i.seq) |
|||
if i.scnt != _SENT_OK { |
|||
deleted++ |
|||
// only outQ could delete at here
|
|||
if l.mode == _QModeOut { |
|||
bpool.Put(i.buffer) |
|||
i.buffer = nil |
|||
} |
|||
} |
|||
} |
|||
left = l.head |
|||
l.head = n.next |
|||
n.next = nil |
|||
if l.head != nil { |
|||
l.head.prev = nil |
|||
} else { // n.next is the tail and is nil
|
|||
l.tail = nil |
|||
} |
|||
return |
|||
} |
|||
|
|||
// xxx - ref
|
|||
// xxx - one - ref
|
|||
func (l *linkedMap) insertBefore(ref, one *qNode) { |
|||
x := ref.prev |
|||
if x != nil { |
|||
x.next = one |
|||
one.prev = x |
|||
} else { |
|||
l.head = one |
|||
} |
|||
ref.prev = one |
|||
one.next = ref |
|||
l.qmap[one.seq] = one |
|||
} |
|||
|
|||
// ref - zzz
|
|||
// ref - one - zzz
|
|||
func (l *linkedMap) insertAfter(ref, one *qNode) { |
|||
z := ref.next |
|||
if z == nil { // append
|
|||
ref.next = one |
|||
l.tail = one |
|||
} else { // insert mid
|
|||
z.prev = one |
|||
ref.next = one |
|||
} |
|||
one.prev = ref |
|||
one.next = z |
|||
l.qmap[one.seq] = one |
|||
} |
|||
|
|||
// baseHead: the left outside boundary
|
|||
// if inserted, return the distance between newNode with baseHead
|
|||
func (l *linkedMap) searchInsert(one *qNode, baseHead uint32) (dis int64) { |
|||
for i := l.tail; i != nil; i = i.prev { |
|||
dis = int64(one.seq) - int64(i.seq) |
|||
if dis > 0 { |
|||
l.insertAfter(i, one) |
|||
return |
|||
} else if dis == 0 { |
|||
// duplicated
|
|||
return |
|||
} |
|||
} |
|||
if one.seq <= baseHead { |
|||
return 0 |
|||
} |
|||
if l.head != nil { |
|||
l.insertBefore(l.head, one) |
|||
} else { |
|||
l.head = one |
|||
l.tail = one |
|||
l.qmap[one.seq] = one |
|||
} |
|||
dis = int64(one.seq) - int64(baseHead) |
|||
return |
|||
} |
|||
|
|||
func (l *linkedMap) updateContinuous(i *qNode) bool { |
|||
var lastCtnSeq = l.maxCtnSeq |
|||
for ; i != nil; i = i.next { |
|||
if i.seq-lastCtnSeq == 1 { |
|||
lastCtnSeq = i.seq |
|||
} else { |
|||
break |
|||
} |
|||
} |
|||
if lastCtnSeq != l.maxCtnSeq { |
|||
l.maxCtnSeq = lastCtnSeq |
|||
return true |
|||
} |
|||
return false |
|||
} |
|||
|
|||
func (l *linkedMap) isWholeContinuous() bool { |
|||
return l.tail != nil && l.maxCtnSeq == l.tail.seq |
|||
} |
|||
|
|||
/* |
|||
func (l *linkedMap) searchMaxContinued(baseHead uint32) (*qNode, bool) { |
|||
var last *qNode |
|||
for i := l.head; i != nil; i = i.next { |
|||
if last != nil { |
|||
if i.seq-last.seq > 1 { |
|||
return last, true |
|||
} |
|||
} else { |
|||
if i.seq != baseHead { |
|||
return nil, false |
|||
} |
|||
} |
|||
last = i |
|||
} |
|||
if last != nil { |
|||
return last, true |
|||
} else { |
|||
return nil, false |
|||
} |
|||
} |
|||
*/ |
|||
|
|||
func (q *qNode) forward(n int) *qNode { |
|||
for ; n > 0 && q != nil; n-- { |
|||
q = q.next |
|||
} |
|||
return q |
|||
} |
|||
|
|||
// prev of bitmap start point
|
|||
func (l *linkedMap) makeHolesBitmap(prev uint32) ([]uint64, uint32) { |
|||
var start *qNode |
|||
var bits uint64 |
|||
var j uint32 |
|||
for i := l.head; i != nil; i = i.next { |
|||
if i.seq >= prev+1 { |
|||
start = i |
|||
break |
|||
} |
|||
} |
|||
var bitmap []uint64 |
|||
// search start which is the recent successor of [prev]
|
|||
Iterator: |
|||
for i := start; i != nil; i = i.next { |
|||
dis := i.seq - prev |
|||
prev = i.seq |
|||
j += dis // j is next bit index
|
|||
for j >= 65 { |
|||
if len(bitmap) >= 20 { // bitmap too long
|
|||
break Iterator |
|||
} |
|||
bitmap = append(bitmap, bits) |
|||
bits = 0 |
|||
j -= 64 |
|||
} |
|||
bits |= 1 << (j - 1) |
|||
} |
|||
if j > 0 { |
|||
// j -> (0, 64]
|
|||
bitmap = append(bitmap, bits) |
|||
} |
|||
return bitmap, j |
|||
} |
|||
|
|||
// from= the bitmap start point
|
|||
func (l *linkedMap) deleteByBitmap(bmap []uint64, from uint32, tailBitsLen uint32) (deleted, missed int32, lastContinued bool) { |
|||
var start = l.qmap[from] |
|||
if start != nil { |
|||
// delete predecessors
|
|||
if pred := start.prev; pred != nil { |
|||
_, n := l.deleteBefore(pred) |
|||
deleted += n |
|||
} |
|||
} else { |
|||
// [from] is out of bounds
|
|||
return |
|||
} |
|||
var j, bitsLen uint32 |
|||
var bits uint64 |
|||
bits, bmap = bmap[0], bmap[1:] |
|||
if len(bmap) > 0 { |
|||
bitsLen = 64 |
|||
} else { |
|||
// bmap.len==1, tail is here
|
|||
bitsLen = tailBitsLen |
|||
} |
|||
// maxContinued will save the max continued node (from [start]) which could be deleted safely.
|
|||
// keep the queue smallest
|
|||
var maxContinued *qNode |
|||
lastContinued = true |
|||
|
|||
for i := start; i != nil; j++ { |
|||
if j >= bitsLen { |
|||
if len(bmap) > 0 { |
|||
j = 0 |
|||
bits, bmap = bmap[0], bmap[1:] |
|||
if len(bmap) > 0 { |
|||
bitsLen = 64 |
|||
} else { |
|||
bitsLen = tailBitsLen |
|||
} |
|||
} else { |
|||
// no more pages
|
|||
goto finished |
|||
} |
|||
} |
|||
if bits&1 == 1 { |
|||
if lastContinued { |
|||
maxContinued = i |
|||
} |
|||
if i.scnt != _SENT_OK { |
|||
// no mark means first deleting
|
|||
deleted++ |
|||
} |
|||
// don't delete, just mark it
|
|||
i.scnt = _SENT_OK |
|||
} else { |
|||
// known it may be lost
|
|||
if i.miss == 0 { |
|||
missed++ |
|||
} |
|||
i.miss++ |
|||
lastContinued = false |
|||
} |
|||
bits >>= 1 |
|||
i = i.next |
|||
} |
|||
|
|||
finished: |
|||
if maxContinued != nil { |
|||
l.deleteBefore(maxContinued) |
|||
} |
|||
return |
|||
} |
@ -0,0 +1,236 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"fmt" |
|||
"math/rand" |
|||
"testing" |
|||
) |
|||
|
|||
// lmap is the list shared by the tests and benchmarks in this file. Several
// tests rely on the state left behind by the test before them.
var lmap *linkedMap

// init creates the shared list in _QModeIn.
func init() {
	lmap = newLinkedMap(_QModeIn)
}
|||
|
|||
func node(seq int) *qNode { |
|||
return &qNode{packet: &packet{seq: uint32(seq)}} |
|||
} |
|||
|
|||
func Test_get(t *testing.T) { |
|||
var i, n *qNode |
|||
assert(!lmap.contains(0), t, "0 nil") |
|||
n = node(0) |
|||
lmap.head = n |
|||
lmap.tail = n |
|||
lmap.qmap[0] = n |
|||
i = lmap.get(0) |
|||
assert(i == n, t, "0=n") |
|||
} |
|||
|
|||
func Test_insert(t *testing.T) { |
|||
lmap.reset() |
|||
n := node(1) |
|||
// appendTail
|
|||
lmap.appendTail(n) |
|||
assert(lmap.head == n, t, "head n") |
|||
assert(lmap.tail == n, t, "head n") |
|||
n = node(2) |
|||
lmap.appendTail(n) |
|||
assert(lmap.head != n, t, "head n") |
|||
assert(lmap.tail == n, t, "head n") |
|||
assert(lmap.size() == 2, t, "size") |
|||
} |
|||
|
|||
func Test_insertAfter(t *testing.T) { |
|||
n1 := lmap.get(1) |
|||
n2 := n1.next |
|||
n3 := node(3) |
|||
lmap.insertAfter(n1, n3) |
|||
assert(n1.next == n3, t, "n3") |
|||
assert(n1 == n3.prev, t, "left n3") |
|||
assert(n2 == n3.next, t, "n3 right") |
|||
} |
|||
|
|||
func Test_insertBefore(t *testing.T) { |
|||
n3 := lmap.get(3) |
|||
n2 := n3.next |
|||
n4 := node(4) |
|||
lmap.insertAfter(n3, n4) |
|||
assert(n3.next == n4, t, "n4") |
|||
assert(n3 == n4.prev, t, "left n4") |
|||
assert(n2 == n4.next, t, "n4 right") |
|||
} |
|||
|
|||
func Test_deleteBefore(t *testing.T) { |
|||
lmap.reset() |
|||
for i := 1; i < 10; i++ { |
|||
n := node(i) |
|||
lmap.appendTail(n) |
|||
} |
|||
|
|||
var assertRangeEquals = func(n *qNode, start, wantCount int) { |
|||
var last *qNode |
|||
var count int |
|||
for ; n != nil; n = n.next { |
|||
assert(int(n.seq) == start, t, "nseq=%d start=%d", n.seq, start) |
|||
last = n |
|||
start++ |
|||
count++ |
|||
} |
|||
assert(last.next == nil, t, "tail nil") |
|||
assert(count == wantCount, t, "count") |
|||
} |
|||
assertRangeEquals(lmap.head, 1, 9) |
|||
var n *qNode |
|||
n = lmap.get(3) |
|||
n, _ = lmap.deleteBefore(n) |
|||
assertRangeEquals(n, 1, 3) |
|||
assert(lmap.head.seq == 4, t, "head") |
|||
|
|||
n = lmap.get(8) |
|||
n, _ = lmap.deleteBefore(n) |
|||
assertRangeEquals(n, 4, 5) |
|||
assert(lmap.head.seq == 9, t, "head") |
|||
|
|||
n = lmap.get(9) |
|||
n, _ = lmap.deleteBefore(n) |
|||
assertRangeEquals(n, 9, 1) |
|||
|
|||
assert(lmap.size() == 0, t, "size 0") |
|||
assert(lmap.head == nil, t, "head nil") |
|||
assert(lmap.tail == nil, t, "tail nil") |
|||
} |
|||
|
|||
// testBitmap walks the shared list and checks bmap bit by bit: a node whose
// seq directly follows the previous one must have its bit set; for a gap, the
// bits of the missing seqs must be clear and the bit of the node itself set.
// j is the bit index inside the current page, k the page index.
func testBitmap(t *testing.T, bmap []uint64, prev uint32) {
	var j uint
	var k int
	bits := bmap[k]
	t.Logf("test-%d %016x", k, bits)
	// checkNextPage moves on to the next page once 64 bits were consumed.
	var checkNextPage = func() {
		if j >= 64 {
			j = 0
			k++
			bits = bmap[k]
			t.Logf("test-%d %016x", k, bits)
		}
	}
	for i := lmap.head; i != nil && k < len(bmap); i = i.next {
		checkNextPage()
		dis := i.seq - prev
		prev = i.seq
		if dis == 1 {
			bit := (bits >> j) & 1
			assert(bit == 1, t, "1 bit=%d j=%d", bit, j)
			j++
		} else {
			// dis-1 holes followed by the node itself
			for ; dis > 0; dis-- {
				checkNextPage()
				bit := (bits >> j) & 1
				want := uint64(0)
				if dis == 1 {
					want = 1
				}
				assert(bit == want, t, "?=%d bit=%d j=%d", want, bit, j)
				j++
			}
		}
	}
	// remains bits should be 0
	for i := j & 63; i > 0; i-- {
		bit := (bits >> j) & 1
		assert(bit == 0, t, "00 bit=%d j=%d", bit, j)
		j++
	}
}
|||
|
|||
// Test_bitmap checks makeHolesBitmap against testBitmap for several hole
// layouts, and checks that deleteByBitmap applied to a full list leaves
// exactly the holes unmarked.
func Test_bitmap(t *testing.T) {
	var prev uint32
	var head uint32 = prev + 1

	lmap.reset()
	// test 66-%3 and record holes
	var holes = make([]uint32, 0, 50)
	for i := head; i < 366; i++ {
		if i%3 == 0 {
			holes = append(holes, i)
			continue
		}
		n := node(int(i))
		lmap.appendTail(n)
	}
	bmap, tbl := lmap.makeHolesBitmap(prev)
	testBitmap(t, bmap, prev)

	lmap.reset()
	// full 66, do deleteByBitmap then compare
	for i := head; i < 366; i++ {
		n := node(int(i))
		lmap.appendTail(n)
	}

	lmap.deleteByBitmap(bmap, head, tbl)
	// nodes not marked _SENT_OK are the ones the bitmap reported as holes
	var holesResult = make([]uint32, 0, 50)
	for i := lmap.head; i != nil; i = i.next {
		if i.scnt != _SENT_OK {
			holesResult = append(holesResult, i.seq)
		}
	}
	a := fmt.Sprintf("%x", holes)
	b := fmt.Sprintf("%x", holesResult)
	assert(a == b, t, "deleteByBitmap \na=%s \nb=%s", a, b)

	lmap.reset()
	// test stride across page 1
	for i := head; i < 69; i++ {
		if i >= 63 && i <= 65 {
			continue
		}
		n := node(int(i))
		lmap.appendTail(n)
	}
	bmap, _ = lmap.makeHolesBitmap(prev)
	testBitmap(t, bmap, prev)

	lmap.reset()
	prev = 65
	head = prev + 1
	// test stride across page 0
	for i := head; i < 68; i++ {
		n := node(int(i))
		lmap.appendTail(n)
	}
	bmap, _ = lmap.makeHolesBitmap(prev)
	testBitmap(t, bmap, prev)
}
|||
|
|||
var ackbitmap []uint64 |
|||
|
|||
func init_benchmark_map() { |
|||
if lmap.size() != 640 { |
|||
lmap.reset() |
|||
for i := 1; i <= 640; i++ { |
|||
lmap.appendTail(node(i)) |
|||
} |
|||
ackbitmap = make([]uint64, 10) |
|||
for i := 0; i < len(ackbitmap); i++ { |
|||
n := rand.Int63() |
|||
ackbitmap[i] = uint64(n) << 1 |
|||
} |
|||
} |
|||
} |
|||
|
|||
func Benchmark_make_bitmap(b *testing.B) { |
|||
init_benchmark_map() |
|||
|
|||
for i := 0; i < b.N; i++ { |
|||
lmap.makeHolesBitmap(0) |
|||
} |
|||
} |
|||
|
|||
func Benchmark_apply_bitmap(b *testing.B) { |
|||
init_benchmark_map() |
|||
|
|||
for i := 0; i < b.N; i++ { |
|||
lmap.deleteByBitmap(ackbitmap, 1, 64) |
|||
} |
|||
} |
@ -0,0 +1,142 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
) |
|||
|
|||
// Packet flag bits stored in packet.flag. Flags are combined with bitwise OR
// (for example _F_SYN|_F_ACK).
const (
	_F_NIL  = 0
	_F_SYN  = 1
	_F_ACK  = 1 << 1
	_F_SACK = 1 << 2
	_F_TIME = 1 << 3
	_F_DATA = 1 << 4
	// reserved = 1 << 5
	_F_RESET = 1 << 6
	_F_FIN   = 1 << 7
)
|||
|
|||
var packetTypeNames = map[byte]string{ |
|||
0: "NOOP", |
|||
1: "SYN", |
|||
2: "ACK", |
|||
3: "SYN+ACK", |
|||
4: "SACK", |
|||
8: "TIME", |
|||
12: "SACK+TIME", |
|||
16: "DATA", |
|||
64: "RESET", |
|||
128: "FIN", |
|||
192: "FIN+RESET", |
|||
} |
|||
|
|||
// Connection/endpoint states, in ascending numeric order starting at 0.
// Conn.IsClosed treats every value up to and including _S_FIN1 as closed.
const (
	_S_FIN = iota
	_S_FIN0
	_S_FIN1
	_S_SYN0
	_S_SYN1
	_S_EST0
	_S_EST1
)
|||
|
|||
// Wire layout: Magic-6 | TH-10 | CH-10 | payload
//
// Note that _TH_SIZE includes the magic bytes, so it is the offset at which
// the connection header starts, and _AH_SIZE is the offset of the payload.
const (
	_MAGIC_SIZE = 6
	_TH_SIZE    = 10 + _MAGIC_SIZE
	_CH_SIZE    = 10
	_AH_SIZE    = _TH_SIZE + _CH_SIZE
)
|||
|
|||
const (
	// Max UDP payload: 1500 MTU - 20 IP hdr - 8 UDP hdr = 1472 bytes
	// Then: MSS = 1472-26 = 1446 (26 bytes being all headers, _AH_SIZE)
	// And For ADSL: 1446-8 = 1438
	_MSS = 1438
)
|||
|
|||
const (
	// _SENT_OK is the packet.scnt value that deleteByBitmap stores on a node
	// whose bitmap bit is set.
	_SENT_OK = 0xff
)
|||
|
|||
// packet is the decoded connection header plus the payload of one datagram.
type packet struct {
	seq     uint32
	ack     uint32
	flag    uint8  // combination of _F_* bits
	scnt    uint8  // written to the wire as is; _SENT_OK marks a node as acknowledged
	payload []byte // bytes following the connection header
	buffer  []byte // complete datagram buffer; when set, marshall writes headers into it
}
|||
|
|||
func (p *packet) marshall(id connID) []byte { |
|||
buf := p.buffer |
|||
if buf == nil { |
|||
buf = make([]byte, _AH_SIZE+len(p.payload)) |
|||
copy(buf[_TH_SIZE+10:], p.payload) |
|||
} |
|||
binary.BigEndian.PutUint16(buf[_MAGIC_SIZE:], uint16(len(buf))) |
|||
binary.BigEndian.PutUint32(buf[_MAGIC_SIZE+2:], id.rid) |
|||
binary.BigEndian.PutUint32(buf[_MAGIC_SIZE+6:], id.lid) |
|||
binary.BigEndian.PutUint32(buf[_TH_SIZE:], p.seq) |
|||
binary.BigEndian.PutUint32(buf[_TH_SIZE+4:], p.ack) |
|||
buf[_TH_SIZE+8] = p.flag |
|||
buf[_TH_SIZE+9] = p.scnt |
|||
return buf |
|||
} |
|||
|
|||
func unmarshall(pk *packet, buf []byte) { |
|||
if len(buf) >= _CH_SIZE { |
|||
pk.seq = binary.BigEndian.Uint32(buf) |
|||
pk.ack = binary.BigEndian.Uint32(buf[4:]) |
|||
pk.flag = buf[8] |
|||
pk.scnt = buf[9] |
|||
pk.payload = buf[10:] |
|||
} |
|||
} |
|||
|
|||
func (n *qNode) String() string { |
|||
now := Now() |
|||
return fmt.Sprintf("type=%s seq=%d scnt=%d sndtime~=%d,%d miss=%d", |
|||
packetTypeNames[n.flag], n.seq, n.scnt, n.sent-now, n.sent_1-now, n.miss) |
|||
} |
|||
|
|||
// maxI64 returns the larger of a and b.
func maxI64(a, b int64) int64 {
	if b > a {
		return b
	}
	return a
}
|||
|
|||
// maxU32 returns the larger of a and b.
func maxU32(a, b uint32) uint32 {
	if b > a {
		return b
	}
	return a
}
|||
|
|||
// minI64 returns the smaller of a and b.
func minI64(a, b int64) int64 {
	if b < a {
		return b
	}
	return a
}
|||
|
|||
// maxI32 returns the larger of a and b.
func maxI32(a, b int32) int32 {
	if b > a {
		return b
	}
	return a
}
|||
|
|||
// minI32 returns the smaller of a and b.
func minI32(a, b int32) int32 {
	if b < a {
		return b
	}
	return a
}
@ -0,0 +1,516 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"errors" |
|||
"log" |
|||
"net" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
) |
|||
|
|||
// Fixed durations used by the close/handshake wait loops.
const (
	_10ms  = time.Millisecond * 10
	_100ms = time.Millisecond * 100
)
|||
|
|||
// Reserved sequence numbers.
const (
	// _FIN_ACK_SEQ is the seq carried by the ACK that answers a FIN
	// (see passiveCloseReply).
	_FIN_ACK_SEQ uint32 = 0xffFF0000
	// _INVALID_SEQ marks a connection id that could not be decoded
	// (see getConnID).
	_INVALID_SEQ uint32 = 0xffFFffFF
)
|||
|
|||
// Errors returned by connection setup and shutdown.
var (
	// ErrIOTimeout is returned when a wait gives up; it is a *TimeoutError.
	ErrIOTimeout error = &TimeoutError{}
	// ErrUnknown is returned by initConnection when the handshake did not
	// reach the established state.
	ErrUnknown = errors.New("Unknown error")
	// ErrInexplicableData is returned when a handshake packet is not of the
	// expected kind.
	ErrInexplicableData = errors.New("Inexplicable data")
	// ErrTooManyAttempts is returned when the handshake retries are exhausted.
	ErrTooManyAttempts = errors.New("Too many attempts to connect")
)
|||
|
|||
// TimeoutError is the error type behind ErrIOTimeout. Its Timeout and
// Temporary methods both report true.
type TimeoutError struct{}

// Error returns the fixed message "i/o timeout".
func (e *TimeoutError) Error() string {
	return "i/o timeout"
}

// Timeout always reports true.
func (e *TimeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (e *TimeoutError) Temporary() bool {
	return true
}
|||
|
|||
// Conn is one connection running over an Endpoint's UDP socket. It holds the
// event channels that drive the internal loops, the protocol state and
// timers, the in/out packet queues, tuning parameters and statistics.
type Conn struct {
	sock   *net.UDPConn // the endpoint's socket (set from e.udpconn in NewConn)
	dest   *net.UDPAddr // peer address
	edp    *Endpoint    // owning endpoint
	connID connID       // 8 bytes
	// events
	evRecv  chan []byte // datagrams from Endpoint.dispatch; nil stops internalRecvLoop
	evRead  chan byte
	evSend  chan byte
	evSWnd  chan byte
	evAck   chan byte
	evClose chan byte
	// protocol state
	inlock       sync.Mutex
	outlock      sync.Mutex
	state        int32 // one of the _S_* values
	mySeq        uint32
	swnd         int32
	cwnd         int32
	missed       int32
	outPending   int32
	lastAck      uint32
	lastAckTime  int64
	lastAckTime2 int64
	lastShrink   int64
	lastRstMis   int64
	ato          int64
	rto          int64
	rtt          int64
	srtt         int64
	mdev         int64
	rtmo         int64
	wtmo         int64
	tSlot        int64
	tSlotT0      int64
	lastSErr     int64
	// queue
	outQ        *linkedMap
	inQ         *linkedMap
	inQReady    []byte
	inQDirty    bool
	lastReadSeq uint32 // last user read seq
	// params
	bandwidth      int64
	fastRetransmit bool
	flatTraffic    bool
	mss            int
	// statistics
	urgent    int
	inPkCnt   int
	inDupCnt  int
	outPkCnt  int
	outDupCnt int
	fRCnt     int
}
|||
|
|||
func NewConn(e *Endpoint, dest *net.UDPAddr, id connID) *Conn { |
|||
c := &Conn{ |
|||
sock: e.udpconn, |
|||
dest: dest, |
|||
edp: e, |
|||
connID: id, |
|||
evRecv: make(chan []byte, 128), |
|||
evRead: make(chan byte, 1), |
|||
evSWnd: make(chan byte, 2), |
|||
evSend: make(chan byte, 4), |
|||
evAck: make(chan byte, 1), |
|||
evClose: make(chan byte, 2), |
|||
outQ: newLinkedMap(_QModeOut), |
|||
inQ: newLinkedMap(_QModeIn), |
|||
} |
|||
p := e.params |
|||
c.bandwidth = p.Bandwidth |
|||
c.fastRetransmit = p.FastRetransmit |
|||
c.flatTraffic = p.FlatTraffic |
|||
c.mss = _MSS |
|||
if dest.IP.To4() == nil { |
|||
// typical ipv6 header length=40
|
|||
c.mss -= 20 |
|||
} |
|||
return c |
|||
} |
|||
|
|||
func (c *Conn) initConnection(buf []byte) (err error) { |
|||
if buf == nil { |
|||
err = c.initDialing() |
|||
} else { //server
|
|||
err = c.acceptConnection(buf[_TH_SIZE:]) |
|||
} |
|||
if err != nil { |
|||
return |
|||
} |
|||
if c.state == _S_EST1 { |
|||
c.lastReadSeq = c.lastAck |
|||
c.inQ.maxCtnSeq = c.lastAck |
|||
c.rtt = maxI64(c.rtt, _MIN_RTT) |
|||
c.mdev = c.rtt << 1 |
|||
c.srtt = c.rtt << 3 |
|||
c.rto = maxI64(c.rtt*2, _MIN_RTO) |
|||
c.ato = maxI64(c.rtt>>4, _MIN_ATO) |
|||
c.ato = minI64(c.ato, _MAX_ATO) |
|||
// initial cwnd
|
|||
c.swnd = calSwnd(c.bandwidth, c.rtt) >> 1 |
|||
c.cwnd = 8 |
|||
go c.internalRecvLoop() |
|||
go c.internalSendLoop() |
|||
go c.internalAckLoop() |
|||
if debug >= 0 { |
|||
go c.internal_state() |
|||
} |
|||
return nil |
|||
} else { |
|||
return ErrUnknown |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) initDialing() error { |
|||
// first syn
|
|||
pk := &packet{ |
|||
seq: c.mySeq, |
|||
flag: _F_SYN, |
|||
} |
|||
item := nodeOf(pk) |
|||
var buf []byte |
|||
c.state = _S_SYN0 |
|||
t0 := Now() |
|||
for i := 0; i < _MAX_RETRIES && c.state == _S_SYN0; i++ { |
|||
// send syn
|
|||
c.internalWrite(item) |
|||
select { |
|||
case buf = <-c.evRecv: |
|||
c.rtt = Now() - t0 |
|||
c.state = _S_SYN1 |
|||
c.connID.setRid(buf) |
|||
buf = buf[_TH_SIZE:] |
|||
case <-time.After(time.Second): |
|||
continue |
|||
} |
|||
} |
|||
if c.state == _S_SYN0 { |
|||
return ErrTooManyAttempts |
|||
} |
|||
|
|||
unmarshall(pk, buf) |
|||
// expected syn+ack
|
|||
if pk.flag == _F_SYN|_F_ACK && pk.ack == c.mySeq { |
|||
if scnt := pk.scnt - 1; scnt > 0 { |
|||
c.rtt -= int64(scnt) * 1e3 |
|||
} |
|||
log.Println("rtt", c.rtt) |
|||
c.state = _S_EST0 |
|||
// build ack3
|
|||
pk.scnt = 0 |
|||
pk.ack = pk.seq |
|||
pk.flag = _F_ACK |
|||
item := nodeOf(pk) |
|||
// send ack3
|
|||
c.internalWrite(item) |
|||
// update lastAck
|
|||
c.logAck(pk.ack) |
|||
c.state = _S_EST1 |
|||
return nil |
|||
} else { |
|||
return ErrInexplicableData |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) acceptConnection(buf []byte) error { |
|||
var pk = new(packet) |
|||
var item *qNode |
|||
unmarshall(pk, buf) |
|||
// expected syn
|
|||
if pk.flag == _F_SYN { |
|||
c.state = _S_SYN1 |
|||
// build syn+ack
|
|||
pk.ack = pk.seq |
|||
pk.seq = c.mySeq |
|||
pk.flag |= _F_ACK |
|||
// update lastAck
|
|||
c.logAck(pk.ack) |
|||
item = nodeOf(pk) |
|||
item.scnt = pk.scnt - 1 |
|||
} else { |
|||
dumpb("Syn1 ?", buf) |
|||
return ErrInexplicableData |
|||
} |
|||
for i := 0; i < 5 && c.state == _S_SYN1; i++ { |
|||
t0 := Now() |
|||
// reply syn+ack
|
|||
c.internalWrite(item) |
|||
// recv ack3
|
|||
select { |
|||
case buf = <-c.evRecv: |
|||
c.state = _S_EST0 |
|||
c.rtt = Now() - t0 |
|||
buf = buf[_TH_SIZE:] |
|||
log.Println("rtt", c.rtt) |
|||
case <-time.After(time.Second): |
|||
continue |
|||
} |
|||
} |
|||
if c.state == _S_SYN1 { |
|||
return ErrTooManyAttempts |
|||
} |
|||
|
|||
pk = new(packet) |
|||
unmarshall(pk, buf) |
|||
// expected ack3
|
|||
if pk.flag == _F_ACK && pk.ack == c.mySeq { |
|||
c.state = _S_EST1 |
|||
} else { |
|||
// if ack3 lost, resend syn+ack 3-times
|
|||
// and drop these coming data
|
|||
if pk.flag&_F_DATA != 0 && pk.seq > c.lastAck { |
|||
c.internalWrite(item) |
|||
c.state = _S_EST1 |
|||
} else { |
|||
dumpb("Ack3 ?", buf) |
|||
return ErrInexplicableData |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// 20,20,20,20, 100,100,100,100, 1s,1s,1s,1s
|
|||
func selfSpinWait(fn func() bool) error { |
|||
const _MAX_SPIN = 12 |
|||
for i := 0; i < _MAX_SPIN; i++ { |
|||
if fn() { |
|||
return nil |
|||
} else if i <= 3 { |
|||
time.Sleep(_10ms * 2) |
|||
} else if i <= 7 { |
|||
time.Sleep(_100ms) |
|||
} else { |
|||
time.Sleep(time.Second) |
|||
} |
|||
} |
|||
return ErrIOTimeout |
|||
} |
|||
|
|||
func (c *Conn) IsClosed() bool { |
|||
return atomic.LoadInt32(&c.state) <= _S_FIN1 |
|||
} |
|||
|
|||
/* |
|||
active close: |
|||
1 <- send fin-W: closeW() |
|||
before sending, ensure all outQ items has beed sent out and all of them has been acked. |
|||
2 -> wait to recv ack{fin-W} |
|||
then trigger closeR, including send fin-R and wait to recv ack{fin-R} |
|||
|
|||
passive close: |
|||
-> fin: |
|||
if outQ is not empty then self-spin wait. |
|||
if outQ empty, send ack{fin-W} then goto closeW(). |
|||
*/ |
|||
func (c *Conn) Close() (err error) { |
|||
if !atomic.CompareAndSwapInt32(&c.state, _S_EST1, _S_FIN0) { |
|||
return selfSpinWait(func() bool { |
|||
return atomic.LoadInt32(&c.state) == _S_FIN |
|||
}) |
|||
} |
|||
var err0 error |
|||
err0 = c.closeW() |
|||
// waiting for fin-2 of peer
|
|||
err = selfSpinWait(func() bool { |
|||
select { |
|||
case v := <-c.evClose: |
|||
if v == _S_FIN { |
|||
return true |
|||
} else { |
|||
time.AfterFunc(_100ms, func() { c.evClose <- v }) |
|||
} |
|||
default: |
|||
} |
|||
return false |
|||
}) |
|||
defer c.afterShutdown() |
|||
if err != nil { |
|||
// backup path for wait ack(finW) timeout
|
|||
c.closeR(nil) |
|||
} |
|||
if err0 != nil { |
|||
return err0 |
|||
} else { |
|||
return |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) beforeCloseW() (err error) { |
|||
// check outQ was empty and all has been acked.
|
|||
// self-spin waiting
|
|||
for i := 0; i < 2; i++ { |
|||
err = selfSpinWait(func() bool { |
|||
return atomic.LoadInt32(&c.outPending) <= 0 |
|||
}) |
|||
if err == nil { |
|||
break |
|||
} |
|||
} |
|||
// send fin, reliably
|
|||
c.outlock.Lock() |
|||
c.mySeq++ |
|||
c.outPending++ |
|||
pk := &packet{seq: c.mySeq, flag: _F_FIN} |
|||
item := nodeOf(pk) |
|||
c.outQ.appendTail(item) |
|||
c.internalWrite(item) |
|||
c.outlock.Unlock() |
|||
c.evSWnd <- _VSWND_ACTIVE |
|||
return |
|||
} |
|||
|
|||
func (c *Conn) closeW() (err error) { |
|||
// close resource of sending
|
|||
defer c.afterCloseW() |
|||
// send fin
|
|||
err = c.beforeCloseW() |
|||
var closed bool |
|||
var max = 20 |
|||
if c.rtt > 200 { |
|||
max = int(c.rtt) / 10 |
|||
} |
|||
// waiting for outQ means:
|
|||
// 1. all outQ has been acked, for passive
|
|||
// 2. fin has been acked, for active
|
|||
for i := 0; i < max && (atomic.LoadInt32(&c.outPending) > 0 || !closed); i++ { |
|||
select { |
|||
case v := <-c.evClose: |
|||
if v == _S_FIN0 { |
|||
// namely, last fin has been acked.
|
|||
closed = true |
|||
} else { |
|||
time.AfterFunc(_100ms, func() { c.evClose <- v }) |
|||
} |
|||
case <-time.After(_100ms): |
|||
} |
|||
} |
|||
if closed || err != nil { |
|||
return |
|||
} else { |
|||
return ErrIOTimeout |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) afterCloseW() { |
|||
// can't close(c.evRecv), avoid endpoint dispatch exception
|
|||
// stop pending inputAndSend
|
|||
select { |
|||
case c.evSend <- _CLOSE: |
|||
default: |
|||
} |
|||
// stop internalSendLoop
|
|||
c.evSWnd <- _CLOSE |
|||
} |
|||
|
|||
// called by active and passive close()
|
|||
func (c *Conn) afterShutdown() { |
|||
// stop internalRecvLoop
|
|||
c.evRecv <- nil |
|||
// remove registry
|
|||
c.edp.removeConn(c.connID, c.dest) |
|||
log.Println("shutdown", c.state) |
|||
} |
|||
|
|||
// trigger by reset
|
|||
func (c *Conn) forceShutdownWithLock() { |
|||
c.outlock.Lock() |
|||
defer c.outlock.Unlock() |
|||
c.forceShutdown() |
|||
} |
|||
|
|||
// called by:
|
|||
// 1/ send exception
|
|||
// 2/ recv reset
|
|||
// drop outQ and force shutdown
|
|||
func (c *Conn) forceShutdown() { |
|||
if atomic.CompareAndSwapInt32(&c.state, _S_EST1, _S_FIN) { |
|||
defer c.afterShutdown() |
|||
// stop sender
|
|||
for i := 0; i < cap(c.evSend); i++ { |
|||
select { |
|||
case <-c.evSend: |
|||
default: |
|||
} |
|||
} |
|||
select { |
|||
case c.evSend <- _CLOSE: |
|||
default: |
|||
} |
|||
c.outQ.reset() |
|||
// stop reader
|
|||
close(c.evRead) |
|||
c.inQ.reset() |
|||
// stop internalLoops
|
|||
c.evSWnd <- _CLOSE |
|||
c.evAck <- _CLOSE |
|||
//log.Println("force shutdown")
|
|||
} |
|||
} |
|||
|
|||
// for sending fin failed
|
|||
func (c *Conn) fakeShutdown() { |
|||
select { |
|||
case c.evClose <- _S_FIN0: |
|||
default: |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) closeR(pk *packet) { |
|||
var passive = true |
|||
for { |
|||
state := atomic.LoadInt32(&c.state) |
|||
switch state { |
|||
case _S_FIN: |
|||
return |
|||
case _S_FIN1: // multiple FIN, maybe lost
|
|||
c.passiveCloseReply(pk, false) |
|||
return |
|||
case _S_FIN0: // active close preformed
|
|||
passive = false |
|||
} |
|||
if !atomic.CompareAndSwapInt32(&c.state, state, _S_FIN1) { |
|||
continue |
|||
} |
|||
c.passiveCloseReply(pk, true) |
|||
break |
|||
} |
|||
// here, R is closed.
|
|||
// ^^^^^^^^^^^^^^^^^^^^^
|
|||
if passive { |
|||
// passive closing call closeW contains sending fin and recv ack
|
|||
// may the ack of fin-2 was lost, then the closeW will timeout
|
|||
c.closeW() |
|||
} |
|||
// here, R,W both were closed.
|
|||
// ^^^^^^^^^^^^^^^^^^^^^
|
|||
atomic.StoreInt32(&c.state, _S_FIN) |
|||
// stop internalAckLoop
|
|||
c.evAck <- _CLOSE |
|||
|
|||
if passive { |
|||
// close evRecv within here
|
|||
c.afterShutdown() |
|||
} else { |
|||
// notify active close thread
|
|||
select { |
|||
case c.evClose <- _S_FIN: |
|||
default: |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (c *Conn) passiveCloseReply(pk *packet, first bool) { |
|||
if pk != nil && pk.flag&_F_FIN != 0 { |
|||
if first { |
|||
c.checkInQ(pk) |
|||
close(c.evRead) |
|||
} |
|||
// ack the FIN
|
|||
pk = &packet{seq: _FIN_ACK_SEQ, ack: pk.seq, flag: _F_ACK} |
|||
item := nodeOf(pk) |
|||
c.internalWrite(item) |
|||
} |
|||
} |
|||
|
|||
// check inQ ends orderly, and copy queue data to user space
|
|||
func (c *Conn) checkInQ(pk *packet) { |
|||
if nil != selfSpinWait(func() bool { |
|||
return c.inQ.maxCtnSeq+1 == pk.seq |
|||
}) { // timeout for waiting inQ to finish
|
|||
return |
|||
} |
|||
c.inlock.Lock() |
|||
defer c.inlock.Unlock() |
|||
if c.inQ.size() > 0 { |
|||
for i := c.inQ.head; i != nil; i = i.next { |
|||
c.inQReady = append(c.inQReady, i.payload...) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,36 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"fmt" |
|||
"os" |
|||
"time" |
|||
) |
|||
|
|||
// watch is a simple stopwatch: a label and the time at which it was started.
type watch struct {
	label string    // prefix of the report line
	t1    time.Time // start time
}
|||
|
|||
func StartWatch(s string) *watch { |
|||
return &watch{ |
|||
label: s, |
|||
t1: time.Now(), |
|||
} |
|||
} |
|||
|
|||
func (w *watch) StopLoops(loop int, size int) { |
|||
tu := time.Now().Sub(w.t1).Nanoseconds() |
|||
timePerLoop := float64(tu) / float64(loop) |
|||
throughput := float64(loop*size) * 1e6 / float64(tu) |
|||
tu_ms := float64(tu) / 1e6 |
|||
fmt.Fprintf(os.Stderr, "%s tu=%.2f ms tpl=%.0f ns throughput=%.2f K/s\n", w.label, tu_ms, timePerLoop, throughput) |
|||
} |
|||
|
|||
var _kt = float64(1e9 / 1024) |
|||
|
|||
func (w *watch) Stop(size int) { |
|||
tu := time.Now().Sub(w.t1).Nanoseconds() |
|||
throughput := float64(size) * _kt / float64(tu) |
|||
tu_ms := float64(tu) / 1e6 |
|||
fmt.Fprintf(os.Stderr, "%s tu=%.2f ms throughput=%.2f K/s\n", w.label, tu_ms, throughput) |
|||
} |
@ -0,0 +1,18 @@ |
|||
package udptransfer |
|||
|
|||
import "time" |
|||
|
|||
// Millisecond is the number of nanoseconds in one millisecond.
const Millisecond = 1e6

// Now returns the current wall-clock time in milliseconds since the Unix
// epoch.
func Now() int64 {
	ns := time.Now().UnixNano()
	return ns / Millisecond
}
|||
|
|||
// NowNS returns the current wall-clock time in nanoseconds since the Unix
// epoch.
func NowNS() int64 {
	now := time.Now()
	return now.UnixNano()
}
|||
|
|||
// NewTimerChan returns a channel that receives the current time once, after
// d milliseconds.
func NewTimerChan(d int64) <-chan time.Time {
	// time.After(x) is NewTimer(x).C
	return time.After(time.Duration(d) * time.Millisecond)
}
@ -0,0 +1,32 @@ |
|||
package udptransfer |
|||
|
|||
import ( |
|||
"math" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func Test_sleep(t *testing.T) { |
|||
const loops = 10 |
|||
var intervals = [...]int64{1, 1e3, 1e4, 1e5, 1e6, 1e7} |
|||
var ret [len(intervals)][loops]int64 |
|||
for j := 0; j < len(intervals); j++ { |
|||
v := time.Duration(intervals[j]) |
|||
for i := 0; i < loops; i++ { |
|||
t0 := NowNS() |
|||
time.Sleep(v) |
|||
ret[j][i] = NowNS() - t0 |
|||
} |
|||
} |
|||
for j := 0; j < len(intervals); j++ { |
|||
var exp, sum, stdev float64 |
|||
exp = float64(intervals[j]) |
|||
for _, v := range ret[j] { |
|||
sum += float64(v) |
|||
stdev += math.Pow(float64(v)-exp, 2) |
|||
} |
|||
stdev /= float64(loops) |
|||
stdev = math.Sqrt(stdev) |
|||
t.Logf("interval=%s sleeping=%s stdev/k=%.2f", time.Duration(intervals[j]), time.Duration(sum/loops), stdev/1e3) |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue