You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

339 lines
6.9 KiB

  1. package udptransfer
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "log"
  7. "math/rand"
  8. "net"
  9. "sort"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/cloudflare/golibs/bytepool"
  14. )
  15. const (
  16. _SO_BUF_SIZE = 8 << 20
  17. )
  18. var (
  19. bpool bytepool.BytePool
  20. )
  21. type Params struct {
  22. LocalAddr string
  23. Bandwidth int64
  24. Mtu int
  25. IsServ bool
  26. FastRetransmit bool
  27. FlatTraffic bool
  28. EnablePprof bool
  29. Stacktrace bool
  30. Debug int
  31. }
  32. type connID struct {
  33. lid uint32
  34. rid uint32
  35. }
  36. type Endpoint struct {
  37. udpconn *net.UDPConn
  38. state int32
  39. idSeq uint32
  40. isServ bool
  41. listenChan chan *Conn
  42. lRegistry map[uint32]*Conn
  43. rRegistry map[string][]uint32
  44. mlock sync.RWMutex
  45. timeout *time.Timer
  46. params Params
  47. }
  48. func (c *connID) setRid(b []byte) {
  49. c.rid = binary.BigEndian.Uint32(b[_MAGIC_SIZE+6:])
  50. }
  51. func init() {
  52. bpool.Init(0, 2000)
  53. rand.Seed(NowNS())
  54. }
  55. func NewEndpoint(p *Params) (*Endpoint, error) {
  56. set_debug_params(p)
  57. if p.Bandwidth <= 0 || p.Bandwidth > 100 {
  58. return nil, fmt.Errorf("bw->(0,100]")
  59. }
  60. conn, err := net.ListenPacket("udp", p.LocalAddr)
  61. if err != nil {
  62. return nil, err
  63. }
  64. e := &Endpoint{
  65. udpconn: conn.(*net.UDPConn),
  66. idSeq: 1,
  67. isServ: p.IsServ,
  68. listenChan: make(chan *Conn, 1),
  69. lRegistry: make(map[uint32]*Conn),
  70. rRegistry: make(map[string][]uint32),
  71. timeout: time.NewTimer(0),
  72. params: *p,
  73. }
  74. if e.isServ {
  75. e.state = _S_EST0
  76. } else { // client
  77. e.state = _S_EST1
  78. e.idSeq = uint32(rand.Int31())
  79. }
  80. e.params.Bandwidth = p.Bandwidth << 20 // mbps to bps
  81. e.udpconn.SetReadBuffer(_SO_BUF_SIZE)
  82. go e.internal_listen()
  83. return e, nil
  84. }
  85. func (e *Endpoint) internal_listen() {
  86. const rtmo = time.Duration(30*time.Second)
  87. var id connID
  88. for {
  89. //var buf = make([]byte, 1600)
  90. var buf = bpool.Get(1600)
  91. e.udpconn.SetReadDeadline(time.Now().Add(rtmo))
  92. n, addr, err := e.udpconn.ReadFromUDP(buf)
  93. if err == nil && n >= _AH_SIZE {
  94. buf = buf[:n]
  95. e.getConnID(&id, buf)
  96. switch id.lid {
  97. case 0: // new connection
  98. if e.isServ {
  99. go e.acceptNewConn(id, addr, buf)
  100. } else {
  101. dumpb("drop", buf)
  102. }
  103. case _INVALID_SEQ:
  104. dumpb("drop invalid", buf)
  105. default: // old connection
  106. e.mlock.RLock()
  107. conn := e.lRegistry[id.lid]
  108. e.mlock.RUnlock()
  109. if conn != nil {
  110. e.dispatch(conn, buf)
  111. } else {
  112. e.resetPeer(addr, id)
  113. dumpb("drop null", buf)
  114. }
  115. }
  116. } else if err != nil {
  117. // idle process
  118. if nerr, y := err.(net.Error); y && nerr.Timeout() {
  119. e.idleProcess()
  120. continue
  121. }
  122. // other errors
  123. if atomic.LoadInt32(&e.state) == _S_FIN {
  124. return
  125. } else {
  126. log.Println("Error: read sock", err)
  127. }
  128. }
  129. }
  130. }
  131. func (e *Endpoint) idleProcess() {
  132. // recycle/shrink memory
  133. bpool.Drain()
  134. e.mlock.Lock()
  135. defer e.mlock.Unlock()
  136. // reset urgent
  137. for _, c := range e.lRegistry {
  138. c.outlock.Lock()
  139. if c.outQ.size() == 0 && c.urgent != 0 {
  140. c.urgent = 0
  141. }
  142. c.outlock.Unlock()
  143. }
  144. }
  145. func (e *Endpoint) Dial(addr string) (*Conn, error) {
  146. rAddr, err := net.ResolveUDPAddr("udp", addr)
  147. if err != nil {
  148. return nil, err
  149. }
  150. e.mlock.Lock()
  151. e.idSeq++
  152. id := connID{e.idSeq, 0}
  153. conn := NewConn(e, rAddr, id)
  154. e.lRegistry[id.lid] = conn
  155. e.mlock.Unlock()
  156. if atomic.LoadInt32(&e.state) != _S_FIN {
  157. err = conn.initConnection(nil)
  158. return conn, err
  159. }
  160. return nil, io.EOF
  161. }
  162. func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) {
  163. rKey := addr.String()
  164. e.mlock.Lock()
  165. // map: remoteAddr => remoteConnID
  166. // filter duplicated syn packets
  167. if newArr := insertRid(e.rRegistry[rKey], id.rid); newArr != nil {
  168. e.rRegistry[rKey] = newArr
  169. } else {
  170. e.mlock.Unlock()
  171. log.Println("Warn: duplicated connection", addr)
  172. return
  173. }
  174. e.idSeq++
  175. id.lid = e.idSeq
  176. conn := NewConn(e, addr, id)
  177. e.lRegistry[id.lid] = conn
  178. e.mlock.Unlock()
  179. err := conn.initConnection(buf)
  180. if err == nil {
  181. select {
  182. case e.listenChan <- conn:
  183. case <-time.After(_10ms):
  184. log.Println("Warn: no listener")
  185. }
  186. } else {
  187. e.removeConn(id, addr)
  188. log.Println("Error: init_connection", addr, err)
  189. }
  190. }
  191. func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) {
  192. e.mlock.Lock()
  193. delete(e.lRegistry, id.lid)
  194. rKey := addr.String()
  195. if newArr := deleteRid(e.rRegistry[rKey], id.rid); newArr != nil {
  196. if len(newArr) > 0 {
  197. e.rRegistry[rKey] = newArr
  198. } else {
  199. delete(e.rRegistry, rKey)
  200. }
  201. }
  202. e.mlock.Unlock()
  203. }
  204. // net.Listener
  205. func (e *Endpoint) Close() error {
  206. state := atomic.LoadInt32(&e.state)
  207. if state > 0 && atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) {
  208. err := e.udpconn.Close()
  209. e.lRegistry = nil
  210. e.rRegistry = nil
  211. select { // release listeners
  212. case e.listenChan <- nil:
  213. default:
  214. }
  215. return err
  216. }
  217. return nil
  218. }
  219. // net.Listener
  220. func (e *Endpoint) Addr() net.Addr {
  221. return e.udpconn.LocalAddr()
  222. }
  223. // net.Listener
  224. func (e *Endpoint) Accept() (net.Conn, error) {
  225. if atomic.LoadInt32(&e.state) == _S_EST0 {
  226. return <-e.listenChan, nil
  227. } else {
  228. return nil, io.EOF
  229. }
  230. }
  231. func (e *Endpoint) Listen() *Conn {
  232. if atomic.LoadInt32(&e.state) == _S_EST0 {
  233. return <-e.listenChan
  234. } else {
  235. return nil
  236. }
  237. }
  238. // tmo in MS
  239. func (e *Endpoint) ListenTimeout(tmo int64) *Conn {
  240. if tmo <= 0 {
  241. return e.Listen()
  242. }
  243. if atomic.LoadInt32(&e.state) == _S_EST0 {
  244. select {
  245. case c := <-e.listenChan:
  246. return c
  247. case <-NewTimerChan(tmo):
  248. }
  249. }
  250. return nil
  251. }
  252. func (e *Endpoint) getConnID(idPtr *connID, buf []byte) {
  253. // TODO determine magic header
  254. magicAndLen := binary.BigEndian.Uint64(buf)
  255. if int(magicAndLen&0xFFff) == len(buf) {
  256. id := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:])
  257. idPtr.lid = uint32(id >> 32)
  258. idPtr.rid = uint32(id)
  259. } else {
  260. idPtr.lid = _INVALID_SEQ
  261. }
  262. }
  263. func (e *Endpoint) dispatch(c *Conn, buf []byte) {
  264. e.timeout.Reset(30*time.Millisecond)
  265. select {
  266. case c.evRecv <- buf:
  267. case <-e.timeout.C:
  268. log.Println("Warn: dispatch packet failed")
  269. }
  270. }
  271. func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) {
  272. pk := &packet{flag: _F_FIN | _F_RESET}
  273. buf := nodeOf(pk).marshall(id)
  274. e.udpconn.WriteToUDP(buf, addr)
  275. }
  276. type u32Slice []uint32
  277. func (p u32Slice) Len() int { return len(p) }
  278. func (p u32Slice) Less(i, j int) bool { return p[i] < p[j] }
  279. func (p u32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  280. // if the rid is not existed in array then insert it return new array
  281. func insertRid(array []uint32, rid uint32) []uint32 {
  282. if len(array) > 0 {
  283. pos := sort.Search(len(array), func(n int) bool {
  284. return array[n] >= rid
  285. })
  286. if pos < len(array) && array[pos] == rid {
  287. return nil
  288. }
  289. }
  290. array = append(array, rid)
  291. sort.Sort(u32Slice(array))
  292. return array
  293. }
  294. // if rid was existed in array then delete it return new array
  295. func deleteRid(array []uint32, rid uint32) []uint32 {
  296. if len(array) > 0 {
  297. pos := sort.Search(len(array), func(n int) bool {
  298. return array[n] >= rid
  299. })
  300. if pos < len(array) && array[pos] == rid {
  301. newArray := make([]uint32, len(array)-1)
  302. n := copy(newArray, array[:pos])
  303. copy(newArray[n:], array[pos+1:])
  304. return newArray
  305. }
  306. }
  307. return nil
  308. }