You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

339 lines
6.9 KiB

package udptransfer
import (
"encoding/binary"
"fmt"
"io"
"log"
"math/rand"
"net"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/cloudflare/golibs/bytepool"
)
// _SO_BUF_SIZE is the receive-buffer size (8 MiB) that NewEndpoint requests
// for the UDP socket via SetReadBuffer.
const _SO_BUF_SIZE = 8 * 1024 * 1024

// bpool is the package-wide pool from which the receive loop takes packet
// buffers; it is initialised in init and drained during idle processing.
var bpool bytepool.BytePool
// Params configures an Endpoint created by NewEndpoint. NewEndpoint keeps a
// copy of it, so later changes to the caller's value have no effect.
type Params struct {
// LocalAddr is the local address handed to net.ListenPacket("udp", ...).
LocalAddr string
// Bandwidth must lie in (0, 100]; NewEndpoint comments it as Mbps and
// stores it shifted left by 20 in its private copy of the params.
Bandwidth int64
// Mtu is not read in this part of the file.
// NOTE(review): presumably the maximum packet size used by Conn — confirm.
Mtu int
// IsServ selects server mode: inbound connection requests (local id 0) are
// accepted and offered to Listen/Accept. When false they are dropped.
IsServ bool
// FastRetransmit and FlatTraffic are not read in this part of the file.
// NOTE(review): presumably consumed by Conn's send logic — confirm.
FastRetransmit bool
FlatTraffic bool
// EnablePprof, Stacktrace and Debug reach set_debug_params (called with the
// whole Params by NewEndpoint).
// NOTE(review): presumably debugging switches — confirm in set_debug_params.
EnablePprof bool
Stacktrace bool
Debug int
}
// connID identifies a connection by its local id (lid) and the peer's id
// (rid). Field order matters: positional literals such as connID{lid, 0}
// are used.
type connID struct {
	lid, rid uint32
}
// Endpoint owns one UDP socket and multiplexes many Conns over it, keyed by
// local connection id. A background goroutine (internal_listen, started by
// NewEndpoint) reads the socket and routes packets.
type Endpoint struct {
// udpconn is the shared socket opened by NewEndpoint.
udpconn *net.UDPConn
// state is accessed with sync/atomic once the endpoint is running:
// _S_EST0 for a server, _S_EST1 for a client, _S_FIN after Close.
state int32
// idSeq is the last local connection id handed out; incremented under mlock.
idSeq uint32
// isServ mirrors Params.IsServ.
isServ bool
// listenChan (capacity 1) carries newly accepted connections to Listen and
// Accept; Close sends nil on it to release a waiter.
listenChan chan *Conn
// lRegistry maps local connection id -> Conn; guarded by mlock.
lRegistry map[uint32]*Conn
// rRegistry maps a remote address string to the sorted remote connection
// ids seen from it, used to filter duplicated SYNs; guarded by mlock.
rRegistry map[string][]uint32
// mlock guards lRegistry, rRegistry and idSeq.
mlock sync.RWMutex
// timeout is a reusable timer for dispatch; in this file dispatch is only
// called from the listener goroutine.
timeout *time.Timer
// params is NewEndpoint's copy of the caller's Params (Bandwidth shifted << 20).
params Params
}
// setRid loads the remote connection id from the big-endian uint32 stored
// _MAGIC_SIZE+6 bytes into packet b.
func (c *connID) setRid(b []byte) {
	ridOffset := _MAGIC_SIZE + 6
	c.rid = binary.BigEndian.Uint32(b[ridOffset:])
}
// init prepares package-wide state: the shared packet buffer pool used by the
// receive loop, and the math/rand seed that NewEndpoint relies on to choose a
// random starting connection id for clients.
func init() {
// NOTE(review): the meaning of (0, 2000) is defined by bytepool's Init — confirm.
bpool.Init(0, 2000)
// NOTE(review): rand.Seed is deprecated as of Go 1.20; consider removing it
// if the minimum supported Go version allows.
rand.Seed(NowNS())
}
// NewEndpoint creates an Endpoint bound to p.LocalAddr and starts its receive
// loop on a background goroutine.
//
// p.Bandwidth (Mbps) must lie in (0, 100], otherwise an error is returned
// before any socket is opened. The endpoint keeps its own copy of *p with
// Bandwidth shifted left by 20. With p.IsServ set, the endpoint accepts
// inbound connections; otherwise it is a client whose local connection ids
// start from a random value.
func NewEndpoint(p *Params) (*Endpoint, error) {
	set_debug_params(p)
	if bw := p.Bandwidth; bw <= 0 || bw > 100 {
		return nil, fmt.Errorf("bw->(0,100]")
	}
	pc, err := net.ListenPacket("udp", p.LocalAddr)
	if err != nil {
		return nil, err
	}
	e := &Endpoint{
		udpconn:    pc.(*net.UDPConn),
		idSeq:      1,
		isServ:     p.IsServ,
		listenChan: make(chan *Conn, 1),
		lRegistry:  make(map[uint32]*Conn),
		rRegistry:  make(map[string][]uint32),
		timeout:    time.NewTimer(0),
		params:     *p,
	}
	if !e.isServ {
		// Client: start local ids from a random point.
		e.state = _S_EST1
		e.idSeq = uint32(rand.Int31())
	} else {
		e.state = _S_EST0
	}
	e.params.Bandwidth <<= 20 // Mbps -> bps
	// Best-effort: failure to enlarge the kernel receive buffer is ignored.
	_ = e.udpconn.SetReadBuffer(_SO_BUF_SIZE)
	go e.internal_listen()
	return e, nil
}
// internal_listen is the endpoint's receive loop; NewEndpoint runs it on its
// own goroutine.
//
// Each datagram is read into a pooled buffer and routed by its local
// connection id:
//   - id 0 starts a new connection (servers only; clients drop it);
//   - _INVALID_SEQ, which getConnID reports on a length mismatch, is dropped;
//   - a registered id is dispatched to its Conn;
//   - an unknown id causes a reset packet to be sent to the sender.
//
// Datagrams shorter than _AH_SIZE are ignored. If nothing arrives for 30s the
// read times out and idleProcess runs. The loop ends when a read fails after
// the state has become _S_FIN.
func (e *Endpoint) internal_listen() {
	const readTimeout = 30 * time.Second
	var cid connID
	for {
		pkt := bpool.Get(1600)
		e.udpconn.SetReadDeadline(time.Now().Add(readTimeout))
		n, from, err := e.udpconn.ReadFromUDP(pkt)
		if err != nil {
			if ne, ok := err.(net.Error); ok && ne.Timeout() {
				// Nothing received within readTimeout: housekeeping.
				e.idleProcess()
				continue
			}
			if atomic.LoadInt32(&e.state) == _S_FIN {
				return
			}
			log.Println("Error: read sock", err)
			continue
		}
		if n < _AH_SIZE {
			continue
		}
		pkt = pkt[:n]
		e.getConnID(&cid, pkt)
		switch cid.lid {
		case 0: // new connection request
			if !e.isServ {
				dumpb("drop", pkt)
				break
			}
			go e.acceptNewConn(cid, from, pkt)
		case _INVALID_SEQ:
			dumpb("drop invalid", pkt)
		default: // existing connection
			e.mlock.RLock()
			conn := e.lRegistry[cid.lid]
			e.mlock.RUnlock()
			if conn == nil {
				e.resetPeer(from, cid)
				dumpb("drop null", pkt)
				break
			}
			e.dispatch(conn, pkt)
		}
	}
}
// idleProcess runs when the receive loop has been idle for its read timeout.
// It drains the shared buffer pool and, for every registered connection whose
// output queue is empty, resets its urgent field to 0 (under the Conn's
// outlock).
func (e *Endpoint) idleProcess() {
	bpool.Drain()

	e.mlock.Lock()
	defer e.mlock.Unlock()
	for _, conn := range e.lRegistry {
		conn.outlock.Lock()
		if conn.outQ.size() == 0 {
			conn.urgent = 0
		}
		conn.outlock.Unlock()
	}
}
// Dial opens a new connection to addr over this endpoint's socket.
//
// It resolves addr, registers a Conn under a fresh local id and initialises
// it with conn.initConnection(nil). It returns io.EOF if the endpoint has
// been closed. If initConnection fails, the Conn is removed from the
// registries and returned together with the error.
func (e *Endpoint) Dial(addr string) (*Conn, error) {
	rAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	e.mlock.Lock()
	// Check for shutdown while holding mlock. Close drops the registry maps,
	// and assigning into a nil map would panic.
	if atomic.LoadInt32(&e.state) == _S_FIN || e.lRegistry == nil {
		e.mlock.Unlock()
		return nil, io.EOF
	}
	e.idSeq++
	id := connID{e.idSeq, 0}
	conn := NewConn(e, rAddr, id)
	e.lRegistry[id.lid] = conn
	e.mlock.Unlock()

	if err = conn.initConnection(nil); err != nil {
		// Do not leave a failed connection registered (acceptNewConn does the
		// same). The conn is still returned so existing callers keep working.
		e.removeConn(id, rAddr)
	}
	return conn, err
}
// acceptNewConn handles a connection request (local id 0) from addr. The
// receive loop runs it on its own goroutine.
//
// A request whose remote id is already registered for addr is treated as a
// duplicated SYN and dropped. Otherwise a Conn is registered under a fresh
// local id, initialised with conn.initConnection(buf), and offered on
// listenChan for up to _10ms. If initialisation fails, the connection is
// deregistered. Nothing is done if the endpoint has been closed.
func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) {
	rKey := addr.String()
	e.mlock.Lock()
	// Close may have run concurrently and dropped the registry maps; writing
	// into a nil map would panic.
	if atomic.LoadInt32(&e.state) == _S_FIN || e.rRegistry == nil || e.lRegistry == nil {
		e.mlock.Unlock()
		return
	}
	// map: remoteAddr => sorted remote connection ids.
	// insertRid returns nil when this rid is already present (duplicated SYN).
	newArr := insertRid(e.rRegistry[rKey], id.rid)
	if newArr == nil {
		e.mlock.Unlock()
		log.Println("Warn: duplicated connection", addr)
		return
	}
	e.rRegistry[rKey] = newArr
	e.idSeq++
	id.lid = e.idSeq
	conn := NewConn(e, addr, id)
	e.lRegistry[id.lid] = conn
	e.mlock.Unlock()

	if err := conn.initConnection(buf); err != nil {
		e.removeConn(id, addr)
		log.Println("Error: init_connection", addr, err)
		return
	}
	select {
	case e.listenChan <- conn:
	case <-time.After(_10ms):
		// NOTE(review): the connection remains registered even though no
		// listener picked it up.
		log.Println("Warn: no listener")
	}
}
// removeConn deregisters connection id from the endpoint. It deletes the
// local-id entry and removes id.rid from addr's remote-id list, dropping the
// list entirely once it becomes empty.
func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) {
	key := addr.String()
	e.mlock.Lock()
	delete(e.lRegistry, id.lid)
	// deleteRid yields nil when rid was not present, which leaves the map unchanged.
	switch remaining := deleteRid(e.rRegistry[key], id.rid); {
	case remaining == nil:
	case len(remaining) == 0:
		delete(e.rRegistry, key)
	default:
		e.rRegistry[key] = remaining
	}
	e.mlock.Unlock()
}
// Close implements net.Listener.
//
// If the endpoint is still open (state > 0), Close:
//   - atomically switches the state to _S_FIN;
//   - closes the UDP socket, which makes the receive loop exit on its next
//     failed read;
//   - drops the connection registries;
//   - tries, without blocking, to wake one goroutine waiting in Listen or
//     Accept by sending nil on listenChan.
//
// Otherwise it returns nil without doing anything. It returns the error from
// closing the socket.
func (e *Endpoint) Close() error {
	state := atomic.LoadInt32(&e.state)
	if state <= 0 || !atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) {
		return nil
	}
	err := e.udpconn.Close()
	// The registries are read and written under mlock by the receive loop,
	// idleProcess, Dial and acceptNewConn. Clear them under the same lock to
	// avoid a data race.
	e.mlock.Lock()
	e.lRegistry = nil
	e.rRegistry = nil
	e.mlock.Unlock()
	select { // release listeners
	case e.listenChan <- nil:
	default:
	}
	return err
}
// Addr implements net.Listener and reports the local address of the
// endpoint's UDP socket.
func (e *Endpoint) Addr() net.Addr {
	local := e.udpconn.LocalAddr()
	return local
}
// Accept implements net.Listener. It blocks until acceptNewConn offers a new
// inbound connection and returns it.
//
// It returns io.EOF when:
//   - the endpoint is not in the server listening state (_S_EST0), or
//   - the wait is ended by Close, which sends a nil *Conn.
func (e *Endpoint) Accept() (net.Conn, error) {
	if atomic.LoadInt32(&e.state) != _S_EST0 {
		return nil, io.EOF
	}
	c := <-e.listenChan
	if c == nil {
		// Returning the nil *Conn directly would produce a non-nil net.Conn
		// that wraps a nil pointer, paired with a nil error.
		return nil, io.EOF
	}
	return c, nil
}
// Listen blocks until a new inbound connection is available and returns it.
// It returns nil immediately if the endpoint is not in the server listening
// state (_S_EST0). It may also return nil when woken by Close.
func (e *Endpoint) Listen() *Conn {
	if atomic.LoadInt32(&e.state) != _S_EST0 {
		return nil
	}
	return <-e.listenChan
}
// ListenTimeout is like Listen but gives up after tmo milliseconds and then
// returns nil. A non-positive tmo waits indefinitely (it calls Listen).
func (e *Endpoint) ListenTimeout(tmo int64) *Conn {
	switch {
	case tmo <= 0:
		return e.Listen()
	case atomic.LoadInt32(&e.state) != _S_EST0:
		return nil
	}
	select {
	case c := <-e.listenChan:
		return c
	case <-NewTimerChan(tmo):
		return nil
	}
}
// getConnID extracts the connection ids from packet buf into *idPtr.
//
// The low 16 bits of the first big-endian uint64 must equal len(buf).
// Otherwise only idPtr.lid is set, to _INVALID_SEQ, and rid is left as is.
// When the length matches, the uint64 at offset _MAGIC_SIZE+2 supplies lid
// (high 32 bits) and rid (low 32 bits).
func (e *Endpoint) getConnID(idPtr *connID, buf []byte) {
	// TODO determine magic header
	header := binary.BigEndian.Uint64(buf)
	if declaredLen := int(header & 0xffff); declaredLen != len(buf) {
		idPtr.lid = _INVALID_SEQ
		return
	}
	packed := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:])
	idPtr.lid, idPtr.rid = uint32(packed>>32), uint32(packed)
}
// dispatch hands packet buf to c's receive channel. It waits at most 30ms
// for room and otherwise drops the packet with a warning.
//
// It reuses the endpoint-wide timer e.timeout, so it must not be called
// concurrently. In this file it is only called from the receive loop.
func (e *Endpoint) dispatch(c *Conn, buf []byte) {
	// Fast path: most of the time the channel has room and no timer is needed.
	select {
	case c.evRecv <- buf:
		return
	default:
	}
	// Stop and drain before Reset. Otherwise a stale expiry can sit in
	// e.timeout.C, from the initial NewTimer(0) or from a timer left running
	// after an earlier successful send. The select below would then choose the
	// timeout case at random even while evRecv is ready, dropping the packet.
	if !e.timeout.Stop() {
		select {
		case <-e.timeout.C:
		default:
		}
	}
	e.timeout.Reset(30 * time.Millisecond)
	select {
	case c.evRecv <- buf:
	case <-e.timeout.C:
		log.Println("Warn: dispatch packet failed")
	}
}
// resetPeer sends addr a packet flagged _F_FIN|_F_RESET for connection id.
// The receive loop uses it for packets addressed to an unknown local id.
// The write is best-effort; its result is discarded.
func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) {
	rst := nodeOf(&packet{flag: _F_FIN | _F_RESET}).marshall(id)
	_, _ = e.udpconn.WriteToUDP(rst, addr)
}
// u32Slice adapts []uint32 to sort.Interface, ordering ascending.
type u32Slice []uint32

func (s u32Slice) Len() int { return len(s) }

func (s u32Slice) Less(a, b int) bool { return s[a] < s[b] }

func (s u32Slice) Swap(a, b int) { s[a], s[b] = s[b], s[a] }
// insertRid returns a new ascending slice holding the elements of array plus
// rid, or nil if rid is already present.
//
// array must be sorted ascending, which holds for every slice produced by
// insertRid and deleteRid. array itself is never modified.
func insertRid(array []uint32, rid uint32) []uint32 {
	pos := sort.Search(len(array), func(n int) bool {
		return array[n] >= rid
	})
	if pos < len(array) && array[pos] == rid {
		return nil
	}
	// sort.Search already found the insertion point, so splice rid in there
	// instead of appending and re-sorting. Building a fresh slice, as
	// deleteRid does, avoids writing into the caller's backing array.
	out := make([]uint32, len(array)+1)
	copy(out, array[:pos])
	out[pos] = rid
	copy(out[pos+1:], array[pos:])
	return out
}
// deleteRid returns a freshly allocated copy of the ascending slice array
// with rid removed. It returns nil when rid is not present, including when
// array is empty. Removing the last element yields an empty, non-nil slice,
// which callers can tell apart from "not found". array is not modified.
func deleteRid(array []uint32, rid uint32) []uint32 {
	size := len(array)
	if size == 0 {
		return nil
	}
	idx := sort.Search(size, func(k int) bool { return array[k] >= rid })
	if idx == size || array[idx] != rid {
		return nil
	}
	rest := make([]uint32, 0, size-1)
	rest = append(rest, array[:idx]...)
	rest = append(rest, array[idx+1:]...)
	return rest
}