package udptransfer

import (
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"time"
)

// Transport tuning limits. Time-valued limits are in milliseconds: c.rtt and
// c.ato are turned into timers with time.Millisecond, and rto is derived from rtt.
const (
	// NOTE(review): not referenced in this section; internalWrite uses its own
	// hard-coded limit of 20 sends.
	_MAX_RETRIES = 6
	// floor for the smoothed rtt computed in measure
	_MIN_RTT     = 8
	// floor for the retransmission timeout computed in measure
	_MIN_RTO     = 30
	// clamp range for the delayed-ACK timeout ato (rtt/16, see measure)
	_MIN_ATO     = 2
	_MAX_ATO     = 10
	// clamp range for the send window returned by calSwnd
	// (presumably a packet count, since calSwnd divides by _MSS)
	_MIN_SWND    = 10
	_MAX_SWND    = 960
)

// Event codes sent over the connection's internal channels.
// The _VACK_* values are ACK urgency levels sent on evAck; they are ordered so
// makeAck can compare them with <. _VSWND_ACTIVE and _VRETR_IMMED are commands
// for internalSendLoop, sent on evSWnd.
const (
	// ACK may wait for the regular delayed-ACK interval (ato)
	_VACK_SCHED = iota + 1
	// ACK may go out early; makeAck throttles it for at most 1 ms
	_VACK_QUICK
	// ACK is sent without throttling
	_VACK_MUST
	// re-arm the retransmission timer (sent after a packet is queued)
	_VSWND_ACTIVE
	// run fast retransmit (retransmit2) immediately
	_VRETR_IMMED
)

const (
	// rest value returned by retransmit when outQ is empty
	_RETR_REST = -1
	// shutdown code: ends internalSendLoop (evSWnd) and internalAckLoop (evAck);
	// writers blocked in inputAndSend also check for it on evSend
	_CLOSE     = 0xff
)

// debug sets log verbosity: >=1 logs RTT estimates, >=2 logs ACK/SACK
// handling and duplicate data, >=3 logs every packet sent and received.
var debug int

// nodeOf wraps pk in a fresh queue node whose other fields are zero.
func nodeOf(pk *packet) *qNode {
	node := new(qNode)
	node.packet = pk
	return node
}

// internalRecvLoop decodes datagrams delivered on evRecv and dispatches them
// by flag until a nil buffer (the shutdown signal) is received.
//
// SACK packets are handled only by processSAck. For any other packet, the ACK
// part is processed first when _F_ACK is set. The packet is then handled as
// DATA, or as a FIN. A FIN with _F_RESET force-closes the connection; any other
// FIN runs closeR. Both FIN handlers start in their own goroutines.
func (c *Conn) internalRecvLoop() {
	// Replayed data packets that arrive while shutting down may make a handler
	// send on an already-closed channel; swallow that panic and stop the loop.
	defer func() {
		_ = recover()
	}()
	for {
		buf := <-c.evRecv
		if buf == nil { // shutdown
			return
		}
		pk := new(packet)
		// keep the whole datagram so its buffer can be recycled later
		pk.buffer = buf
		unmarshall(pk, buf[_TH_SIZE:])

		if pk.flag&_F_SACK != 0 {
			c.processSAck(pk)
			continue
		}
		if pk.flag&_F_ACK != 0 {
			c.processAck(pk)
		}
		switch {
		case pk.flag&_F_DATA != 0:
			c.insertData(pk)
		case pk.flag&_F_FIN == 0:
			// neither DATA nor FIN: nothing more to do
		case pk.flag&_F_RESET != 0:
			go c.forceShutdownWithLock()
		default:
			go c.closeR(pk)
		}
	}
}

// internalSendLoop drives retransmission of the outgoing queue.
//
// Commands received on evSWnd:
//   - _VRETR_IMMED: run fast retransmit (retransmit2) under outlock;
//   - _VSWND_ACTIVE: re-arm the retransmission timer for one rtt;
//   - _CLOSE: exit the loop.
//
// When the timer fires, retransmit runs under outlock. The timer is then
// re-armed for the nearest pending RTO deadline (capped at one rtt). If no
// deadline was found, it is re-armed for one rtt while outQ is non-empty, or
// stopped once outQ is empty. In that last case a non-blocking send on evSend
// wakes a writer that may be waiting in inputAndSend.
//
// NOTE(review): timer.Reset is called without stopping/draining the timer; on
// Go versions before 1.23 a tick already buffered in timer.C can fire once
// more right after a Reset. retransmit tolerates such a spurious wakeup.
func (c *Conn) internalSendLoop() {
	// c.rtt is in milliseconds
	var timer = time.NewTimer(time.Duration(c.rtt) * time.Millisecond)
	for {
		select {
		case v := <-c.evSWnd:
			switch v {
			case _VRETR_IMMED:
				c.outlock.Lock()
				c.retransmit2()
				c.outlock.Unlock()
			case _VSWND_ACTIVE:
				timer.Reset(time.Duration(c.rtt) * time.Millisecond)
			case _CLOSE:
				return
			}
		case <-timer.C: // retransmission timer expired
			var notifySender bool
			c.outlock.Lock()
			rest, _ := c.retransmit()
			switch rest {
			case _RETR_REST, 0: // no future RTO deadline was found
				if c.outQ.size() > 0 {
					timer.Reset(time.Duration(c.rtt) * time.Millisecond)
				} else {
					timer.Stop()
					// avoid leaving a sender blocked on a full window
					notifySender = true
				}
			default: // wake at the nearest RTO deadline, at most one rtt away
				timer.Reset(time.Duration(minI64(rest, c.rtt)) * time.Millisecond)
			}
			c.outlock.Unlock()
			if notifySender {
				select {
				case c.evSend <- 1:
				default:
				}
			}
		}
	}
}

// internalAckLoop emits ACKs for the receiving side.
//
// Each value received on evAck is an urgency level (_VACK_*) and re-arms
// ackTimer for c.ato ms. A value that differs from the previous one is promoted
// to _VACK_MUST, except _CLOSE, which ends the loop. ackTimer is re-armed only
// by evAck events, so it fires at most once per burst. When it fires, a
// _VACK_QUICK ack is attempted. The ack is built by makeAck under inlock and
// written immediately (makeAck may return nil when throttled).
func (c *Conn) internalAckLoop() {
	// var ackTimer = time.NewTicker(time.Duration(c.ato))
	var ackTimer = time.NewTimer(time.Duration(c.ato) * time.Millisecond)
	var lastAckState byte
	for {
		var v byte
		select {
		case <-ackTimer.C:
			// may cause sending duplicated ack if ato>rtt
			v = _VACK_QUICK
		case v = <-c.evAck:
			ackTimer.Reset(time.Duration(c.ato) * time.Millisecond)
			state := lastAckState
			lastAckState = v
			// a change of urgency level forces an immediate ack
			if state != v {
				if v == _CLOSE {
					return
				}
				v = _VACK_MUST
			}
		}
		c.inlock.Lock()
		if pkAck := c.makeAck(v); pkAck != nil {
			c.internalWrite(nodeOf(pkAck))
		}
		c.inlock.Unlock()
	}
}

// retransmit resends every unacked packet in outQ whose last send is more than
// c.rto ms ago. Callers in this file hold outlock.
//
// Returns:
//   - rest: the smallest remaining time (ms) until a not-yet-expired packet
//     reaches its RTO, or 0 if no such packet was seen;
//   - count: the number of packets resent.
//
// If outQ.size() is 0 it returns (_RETR_REST, 0) instead.
//
// limit starts at c.cwnd and is decremented only for packets that were NOT
// resent, so the scan stops after cwnd still-pending packets. Expired packets
// are resent without consuming it.
//
// If many packets had to be resent (the threshold depends on fastRetransmit)
// and cwnd was not shrunk within the last rto, cwnd is halved. The halving only
// happens while cwnd > swnd/2, which keeps cwnd above swnd/4.
func (c *Conn) retransmit() (rest int64, count int32) {
	var now, rto = Now(), c.rto
	var limit = c.cwnd
	for item := c.outQ.head; item != nil && limit > 0; item = item.next {
		if item.scnt != _SENT_OK { // ACKed has scnt==-1
			diff := now - item.sent
			if diff > rto { // already rto
				c.internalWrite(item)
				count++
			} else {
				// continue search next min rto duration
				if rest > 0 {
					rest = minI64(rest, rto-diff+1)
				} else {
					rest = rto - diff + 1
				}
				limit--
			}
		}
	}
	c.outDupCnt += int(count)
	if count > 0 {
		shrcond := (c.fastRetransmit && count > maxI32(c.cwnd>>5, 4)) || (!c.fastRetransmit && count > c.cwnd>>3)
		if shrcond && now-c.lastShrink > c.rto {
			// NOTE(review): logged regardless of debug level, and "to=" shows
			// cwnd>>1 even when the swnd/2 guard below skips the shrink.
			log.Printf("shrink cwnd from=%d to=%d s/4=%d", c.cwnd, c.cwnd>>1, c.swnd>>2)
			c.lastShrink = now
			// shrink cwnd and ensure cwnd >= swnd/4
			if c.cwnd > c.swnd>>1 {
				c.cwnd >>= 1
			}
		}
	}
	if c.outQ.size() > 0 {
		return
	}
	return _RETR_REST, 0
}

// retransmit2 performs fast retransmission. internalSendLoop calls it on
// _VRETR_IMMED with outlock held.
//
// It resends at most min(outPending/16, 8) packets, so nothing is resent while
// fewer than 16 packets are pending. A packet qualifies when it is unacked, its
// miss counter is >= 3 (presumably bumped when SACKs report it missing; that
// happens outside this section), and it was last sent at least fRtt ms ago.
// fRtt is rtt plus a margin: rtt/16 (min 1) when cwnd has not shrunk within
// the last rto, otherwise rtt/2 (min 2). Each resent packet's miss counter is
// reset. The returned count is also added to fRCnt and outDupCnt.
func (c *Conn) retransmit2() (count int32) {
	var limit, now = minI32(c.outPending>>4, 8), Now()
	var fRtt = c.rtt
	if now-c.lastShrink > c.rto {
		fRtt += maxI64(c.rtt>>4, 1)
	} else {
		// recently congested: wait longer before resending
		fRtt += maxI64(c.rtt>>1, 2)
	}
	for item := c.outQ.head; item != nil && count < limit; item = item.next {
		if item.scnt != _SENT_OK { // ACKed has scnt==-1
			if item.miss >= 3 && now-item.sent >= fRtt {
				item.miss = 0
				c.internalWrite(item)
				count++
			}
		}
	}
	c.fRCnt += int(count)
	c.outDupCnt += int(count)
	return
}

// inputAndSend assigns the next sequence number to pk, appends it to outQ and
// sends it once.
//
// While outPending >= cwnd+missed it releases outlock and waits for a signal
// on evSend. A pending write timeout (c.wtmo, ms) bounds only that wait and is
// consumed (reset to 0) when first used. The function returns io.EOF if _CLOSE
// arrives on evSend and ErrIOTimeout if the timeout expires. In both cases pk
// has not been queued.
//
// After queuing, it tells internalSendLoop (blocking send) to re-arm the
// retransmission timer. With flatTraffic set, every fourth sequence number it
// compares the time elapsed since tSlotT0 with a budget of 4*tSlot (ns). If it
// is more than 100us ahead, it sleeps for half the difference and carries the
// sleep overshoot into lastSErr.
//
// NOTE(review): c.mySeq is read outside outlock (first and last checks); this
// relies on Write not being called concurrently. The blocking send on evSWnd
// would hang if internalSendLoop had already exited.
func (c *Conn) inputAndSend(pk *packet) error {
	item := &qNode{packet: pk}
	if c.mySeq&3 == 1 {
		c.tSlotT0 = NowNS()
	}
	c.outlock.Lock()
	// block while in-flight packets exceed the window
	// in-flight includes: 1, unacked; 2, missed
	for c.outPending >= c.cwnd+c.missed {
		c.outlock.Unlock()
		if c.wtmo > 0 {
			var tmo int64
			tmo, c.wtmo = c.wtmo, 0
			select {
			case v := <-c.evSend:
				if v == _CLOSE {
					return io.EOF
				}
			case <-NewTimerChan(tmo):
				return ErrIOTimeout
			}
		} else {
			if v := <-c.evSend; v == _CLOSE {
				return io.EOF
			}
		}
		c.outlock.Lock()
	}
	c.outPending++
	c.outPkCnt++
	c.mySeq++
	pk.seq = c.mySeq
	c.outQ.appendTail(item)
	c.internalWrite(item)
	c.outlock.Unlock()
	// activate the resending timer; must block
	c.evSWnd <- _VSWND_ACTIVE
	if c.mySeq&3 == 0 && c.flatTraffic {
		// calculate the time error between the tslot budget and actual usage,
		// taking the previous sleep's overshoot into account.
		t1 := NowNS()
		terr := c.tSlot<<2 - c.lastSErr - (t1 - c.tSlotT0)
		// sleep terr/2 when ahead of the budget by more than 100us.
		if terr > 1e5 { // 100us
			time.Sleep(time.Duration(terr >> 1))
			c.lastSErr = maxI64(NowNS()-t1-terr, 0)
		} else {
			c.lastSErr >>= 1
		}
	}
	return nil
}

// internalWrite marshals item and sends it to c.dest over the UDP socket. It
// records the new send time (keeping the previous one in sent_1) and
// increments the send counter scnt.
//
// Retry limit: once scnt reaches 20, a FIN packet triggers fakeShutdown and
// clears c.dest. Any other packet is logged. The first time this happens on the
// connection (c.urgent == 0), scnt is reset to 10, allowing about 10 more
// sends; after that the connection is force-closed. The WriteToUDP error is
// ignored (best-effort datagram send).
func (c *Conn) internalWrite(item *qNode) {
	if item.scnt >= 20 {
		// no exception of sending fin
		if item.flag&_F_FIN != 0 {
			c.fakeShutdown()
			c.dest = nil
			return
		} else {
			log.Println("Warn: too many retries", item)
			if c.urgent > 0 { // abort
				c.forceShutdown()
				return
			} else { // grant roughly 10 more sends
				c.urgent++
				item.scnt = 10
			}
		}
	}
	// update current sent time and prev sent time
	item.sent, item.sent_1 = Now(), item.sent
	item.scnt++
	buf := item.marshall(c.connID)
	if debug >= 3 {
		var pkType = packetTypeNames[item.flag]
		if item.flag&_F_SACK != 0 {
			log.Printf("send %s trp=%d on=%d %x", pkType, item.seq, item.ack, buf[_AH_SIZE+4:])
		} else {
			log.Printf("send %s seq=%d ack=%d scnt=%d len=%d", pkType, item.seq, item.ack, item.scnt, len(buf)-_TH_SIZE)
		}
	}
	c.sock.WriteToUDP(buf, c.dest)
}

// logAck remembers ack as the most recently acknowledged sequence and stamps
// the current time; makeAck and makeLastAck use that stamp for throttling.
func (c *Conn) logAck(ack uint32) {
	now := Now()
	c.lastAck, c.lastAckTime = ack, now
}

// makeLastAck builds a plain cumulative ACK (_F_ACK) for the larger of the last
// acknowledged sequence and the end of the in-order receive prefix. It takes
// inlock itself. It returns nil when an ACK was produced less than one rtt ago,
// which rate-limits replies to repeated handshake ACKs seen by processAck.
func (c *Conn) makeLastAck() (pk *packet) {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	sinceLast := Now() - c.lastAckTime
	if sinceLast >= c.rtt {
		ackNo := maxU32(c.lastAck, c.inQ.maxCtnSeq)
		pk = &packet{flag: _F_ACK, ack: ackNo}
		c.logAck(ackNo)
	}
	return pk
}

// makeAck builds a SACK packet describing the receive queue, or returns nil
// when throttled. Callers in this file hold inlock.
//
// Throttling by level (time since the last ack):
//   - _VACK_MUST and above: never throttled;
//   - _VACK_QUICK: suppressed if less than min(ato/4, 1) ms;
//   - _VACK_SCHED: suppressed if less than ato ms.
//
// Payload layout: TBL:1 | SCNT:1 | DELAY:2 (big-endian), followed by the holes
// bitmap as big-endian uint64 words. pk.ack is maxCtnSeq+1, the first sequence
// the bitmap covers. When there are no holes, a one-bit "fake" bitmap is sent
// and pk.ack is maxCtnSeq.
//
// If the most recently inserted packet (inQ.lastIns) arrived less than rtt ms
// ago, _F_TIME is set and the packet echoes that seq and scnt plus the delay
// (>= 1 ms) between receipt and this ack, so the sender can measure RTT.
func (c *Conn) makeAck(level byte) (pk *packet) {
	now := Now()
	// NOTE(review): minI64 caps the QUICK throttle at 1 ms (and 0 — no throttle —
	// when ato < 4); maxI64, as used elsewhere with ">>k, 1", may have been
	// intended. Confirm before changing.
	if level < _VACK_MUST && now-c.lastAckTime < c.ato {
		if level < _VACK_QUICK || now-c.lastAckTime < minI64(c.ato>>2, 1) {
			return
		}
	}
	//	    ready Q <-|
	//	              |-> outQ start (or more right)
	//	              |-> bitmap start
	//	[predecessor]  [predecessor+1]  [predecessor+2] .....
	var fakeSAck bool
	var predecessor = c.inQ.maxCtnSeq
	bmap, tbl := c.inQ.makeHolesBitmap(predecessor)
	if len(bmap) <= 0 { // fake sack
		bmap = make([]uint64, 1)
		bmap[0], tbl = 1, 1
		fakeSAck = true
	}
	// head 4-byte: TBL:1 | SCNT:1 | DELAY:2
	buf := make([]byte, len(bmap)*8+4)
	pk = &packet{
		ack:     predecessor + 1,
		flag:    _F_SACK,
		payload: buf,
	}
	if fakeSAck {
		pk.ack--
	}
	buf[0] = byte(tbl)
	// mark delayed time according to the time reference point
	// (insertData stores the receive time in sent)
	if trp := c.inQ.lastIns; trp != nil {
		delayed := now - trp.sent
		if delayed < c.rtt {
			pk.seq = trp.seq
			pk.flag |= _F_TIME
			buf[1] = trp.scnt
			if delayed <= 0 {
				delayed = 1
			}
			binary.BigEndian.PutUint16(buf[2:], uint16(delayed))
		}
	}
	buf1 := buf[4:]
	for i, b := range bmap {
		binary.BigEndian.PutUint64(buf1[i*8:], b)
	}
	c.logAck(predecessor)
	return
}

// unmarshallSAck decodes a SACK payload as built by makeAck:
//
//	TBL:1 | SCNT:1 | DELAY:2 (big-endian) | bitmap word 0 (uint64, big-endian) | ...
//
// The payload comes straight off the network, so its length is validated. A
// payload too short for the 4-byte header plus one bitmap word yields a nil
// bmap, which processSAck treats as a bad packet. Trailing bytes that do not
// form a whole word are ignored. The earlier version indexed past the end for
// such input and panicked, which stopped the receive loop.
func unmarshallSAck(data []byte) (bmap []uint64, tbl uint32, delayed uint16, scnt uint8) {
	const headerSize = 4
	if len(data) < headerSize+8 {
		return
	}
	tbl = uint32(data[0])
	scnt = data[1]
	delayed = binary.BigEndian.Uint16(data[2:headerSize])
	words := data[headerSize:]
	bmap = make([]uint64, len(words)>>3)
	for i := range bmap {
		bmap[i] = binary.BigEndian.Uint64(words[i*8:])
	}
	return
}

// calSwnd converts bandwidth × rtt into a send window clamped to
// [_MIN_SWND, _MAX_SWND].
//
// The 8000 divisor suggests bandwidth is in bits/s and rtt in ms. That gives a
// bandwidth-delay product in bytes, which is divided by _MSS, so the window is
// presumably a packet count (TODO confirm units against the caller).
//
// The clamp is applied in int64 before narrowing to int32, so a quotient beyond
// the int32 range cannot wrap negative and collapse to _MIN_SWND.
func calSwnd(bandwidth, rtt int64) int32 {
	w := bandwidth * rtt / (8000 * _MSS)
	switch {
	case w > _MAX_SWND:
		return _MAX_SWND
	case w < _MIN_SWND:
		return _MIN_SWND
	default:
		return int32(w)
	}
}

// measure takes an RTT sample from a timestamp echoed in a SACK and updates the
// derived parameters: srtt, rtt, swnd, tSlot, ato, mdev and rto.
// It is called from processSAck with outlock held.
//
// seq and scnt identify the data packet whose receipt the peer timed, and how
// many times it had been sent at that point. delayed is how long (ms) the peer
// held it before acking. A sample is taken only if the packet was sent at most
// once more since then (scnt difference 0 or 1), so the matching send time is
// known.
func (c *Conn) measure(seq uint32, delayed int64, scnt uint8) {
	target := c.outQ.get(seq)
	if target != nil {
		var lastSent int64
		switch target.scnt - scnt {
		case 0:
			// not sent again since this ack was sent out
			lastSent = target.sent
		case 1:
			// sent again once since this ack was sent out
			// then use prev sent time
			lastSent = target.sent_1
		default:
			// can't measure here because the packet was sent too many times
			return
		}
		// real-time rtt
		rtt := Now() - lastSent - delayed
		// reject these abnormal measures:
		// 1. sample below max(rtt/8, 1) ms
		// 2. peer held the packet longer than rtt/2 before acking
		if rtt < maxI64(c.rtt>>3, 1) || delayed > c.rtt>>1 {
			return
		}
		// srtt holds 8*rtt: EWMA with gain 1/8
		err := rtt - (c.srtt >> 3)
		c.srtt += err
		c.rtt = c.srtt >> 3
		if c.rtt < _MIN_RTT {
			c.rtt = _MIN_RTT
		}
		// swnd: EWMA with gain 1/8, swnd = (7*swnd + target) / 8
		swnd := c.swnd<<3 - c.swnd + calSwnd(c.bandwidth, c.rtt)
		c.swnd = swnd >> 3
		// pacing slot per packet in ns (rtt is ms)
		c.tSlot = c.rtt * 1e6 / int64(c.swnd)
		c.ato = c.rtt >> 4
		if c.ato < _MIN_ATO {
			c.ato = _MIN_ATO
		} else if c.ato > _MAX_ATO {
			c.ato = _MAX_ATO
		}
		// a falling rtt feeds only 1/8 of its excess deviation into mdev
		if err < 0 {
			err = -err
			err -= c.mdev >> 2
			if err > 0 {
				err >>= 3
			}
		} else {
			err -= c.mdev >> 2
		}
		// mdev: update 1/4
		c.mdev += err
		// rto = rtt + max(2*rtt, mdev); increases apply at once,
		// decreases are averaged with the previous rto
		rto := c.rtt + maxI64(c.rtt<<1, c.mdev)
		if rto >= c.rto {
			c.rto = rto
		} else {
			c.rto = (c.rto + rto) >> 1
		}
		if c.rto < _MIN_RTO {
			c.rto = _MIN_RTO
		}
		if debug >= 1 {
			log.Printf("--- rtt=%d srtt=%d rto=%d swnd=%d", c.rtt, c.srtt, c.rto, c.swnd)
		}
	}
}

// processSAck handles an incoming SACK packet.
//
// Under outlock it decodes the payload and drops malformed packets (nil
// bitmap). If _F_TIME is set, it takes an RTT sample. It then removes the
// packets the bitmap acknowledges from outQ. When something was deleted,
// ackHit updates the windows and releases outlock.
//
// If fast retransmit is enabled and the peer reported holes (!continuous), it
// asks internalSendLoop for an immediate retransmit. That send blocks when
// nothing was deleted and is non-blocking otherwise.
//
// NOTE(review): the debug log reads outQ and outPending after outlock has been
// released.
func (c *Conn) processSAck(pk *packet) {
	c.outlock.Lock()
	bmap, tbl, delayed, scnt := unmarshallSAck(pk.payload)
	if bmap == nil { // bad packet
		c.outlock.Unlock()
		return
	}
	if pk.flag&_F_TIME != 0 {
		c.measure(pk.seq, int64(delayed), scnt)
	}
	deleted, missed, continuous := c.outQ.deleteByBitmap(bmap, pk.ack, tbl)
	if deleted > 0 {
		c.ackHit(deleted, missed)
		// lock is released
	} else {
		c.outlock.Unlock()
	}
	if c.fastRetransmit && !continuous {
		// peer Q is uncontinuous, then trigger FR
		if deleted == 0 {
			c.evSWnd <- _VRETR_IMMED
		} else {
			select {
			case c.evSWnd <- _VRETR_IMMED:
			default:
			}
		}
	}
	if debug >= 2 {
		log.Printf("SACK qhead=%d deleted=%d outPending=%d on=%d %016x",
			c.outQ.distanceOfHead(0), deleted, c.outPending, pk.ack, bmap)
	}
}

// processAck handles a cumulative ACK.
//
// If pk.ack names a packet still in outQ, outQ.deleteBefore removes the entries
// up to it and ackHit updates the windows (releasing outlock). An ACK whose seq
// is _FIN_ACK_SEQ also posts _S_FIN0 on evClose (non-blocking).
//
// Otherwise the ACK is a duplicate. If it carries _F_SYN (per the original
// note, the handshake's third ACK was lost), the last ACK is rebuilt with
// makeLastAck and resent. makeLastAck takes inlock while outlock is still held
// here.
func (c *Conn) processAck(pk *packet) {
	c.outlock.Lock()
	if end := c.outQ.get(pk.ack); end != nil { // ack hit
		_, deleted := c.outQ.deleteBefore(end)
		c.ackHit(deleted, 0) // lock is released
		if debug >= 2 {
			log.Printf("ACK hit on=%d", pk.ack)
		}
		// special case: ack the FIN
		if pk.seq == _FIN_ACK_SEQ {
			select {
			case c.evClose <- _S_FIN0:
			default:
			}
		}
	} else { // duplicated ack
		if debug >= 2 {
			log.Printf("ACK miss on=%d", pk.ack)
		}
		if pk.flag&_F_SYN != 0 { // No.3 Ack lost
			if pkAck := c.makeLastAck(); pkAck != nil {
				c.internalWrite(nodeOf(pkAck))
			}
		}
		c.outlock.Unlock()
	}
}

// ackHit updates send-side state after `deleted` packets were acknowledged and
// `missed` were reported missing.
//
// It must be entered with outlock held and RELEASES it before returning.
//
// cwnd grows only while cwnd < swnd and no shrink happened within the last rto.
// Below swnd/2 it doubles; otherwise it grows by 2*deleted. It is always capped
// at swnd. The missed estimate is reset at most once every ato ms; in between
// it halves and adds the new count, and it is capped at swnd/16. Finally a
// non-blocking send on evSend wakes a writer waiting for window space.
func (c *Conn) ackHit(deleted, missed int32) {
	// must in outlock
	c.outPending -= deleted
	now := Now()
	if c.cwnd < c.swnd && now-c.lastShrink > c.rto {
		if c.cwnd < c.swnd>>1 {
			c.cwnd <<= 1
		} else {
			c.cwnd += deleted << 1
		}
	}
	if c.cwnd > c.swnd {
		c.cwnd = c.swnd
	}
	if now-c.lastRstMis > c.ato {
		c.lastRstMis = now
		c.missed = missed
	} else {
		c.missed = c.missed>>1 + missed
	}
	if qswnd := c.swnd >> 4; c.missed > qswnd {
		c.missed = qswnd
	}
	c.outlock.Unlock()
	select {
	case c.evSend <- 1:
	default:
	}
}

// insertData queues an incoming DATA packet in inQ under inlock.
//
// Duplicates (already queued, or seq <= maxCtnSeq) request an ACK with
// _VACK_MUST (non-blocking), since the previous ACK was probably lost, and
// increment inDupCnt.
//
// Other packets are inserted with their receive time stored in sent. The
// distance returned by searchInsert then selects the ACK urgency:
//   - 1 with no known holes (fast path): maxCtnSeq advances, _VACK_SCHED;
//   - 1 while holes exist: continuity is recomputed. The result is _VACK_MUST
//     if the whole queue became continuous, _VACK_QUICK otherwise;
//   - anything else: a hole exists; inQDirty is set and _VACK_MUST is used;
//   - 0 is treated as impossible and returns without further bookkeeping.
//
// The ack loop and, when in-order data became available, the reader are
// notified with non-blocking sends.
func (c *Conn) insertData(pk *packet) {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	exists := c.inQ.contains(pk.seq)
	// duplicated with already queued or history
	// means: last ACK were lost
	if exists || pk.seq <= c.inQ.maxCtnSeq {
		// then send ACK for dups
		select {
		case c.evAck <- _VACK_MUST:
		default:
		}
		if debug >= 2 {
			dumpQ(fmt.Sprint("duplicated ", pk.seq), c.inQ)
		}
		c.inDupCnt++
		return
	}
	// record current time in sent and regard as received time
	item := &qNode{packet: pk, sent: Now()}
	dis := c.inQ.searchInsert(item, c.lastReadSeq)
	if debug >= 3 {
		log.Printf("\t\t\trecv DATA seq=%d dis=%d maxCtn=%d lastReadSeq=%d", item.seq, dis, c.inQ.maxCtnSeq, c.lastReadSeq)
	}

	var ackState byte = _VACK_MUST
	var available bool
	switch dis {
	case 0: // impossible
		return
	case 1:
		if c.inQDirty {
			available = c.inQ.updateContinuous(item)
			if c.inQ.isWholeContinuous() { // whole Q is ordered
				c.inQDirty = false
			} else { //those holes still exists.
				ackState = _VACK_QUICK
			}
		} else {
			// here is an ideal situation
			c.inQ.maxCtnSeq = pk.seq
			available = true
			ackState = _VACK_SCHED
		}

	default: // there is an unordered packet, hole occurred here.
		if !c.inQDirty {
			c.inQDirty = true
		}
	}

	// write valid received count
	c.inPkCnt++
	// time reference point echoed by makeAck for RTT measurement
	c.inQ.lastIns = item
	// try notify ack
	select {
	case c.evAck <- ackState:
	default:
	}
	if available { // try notify reader
		select {
		case c.evRead <- 1:
		default:
		}
	}
}

// readInQ moves the in-order run of received packets that starts right after
// lastReadSeq out of inQ and appends their payloads to inQReady. Each packet's
// receive buffer is recycled. It reports whether such a run was found.
func (c *Conn) readInQ() bool {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	//  already read <-|-> expected next
	//  [lastReadSeq]  |  [lastReadSeq+1] [lastReadSeq+2] ...
	if !c.inQ.isEqualsHead(c.lastReadSeq+1) || c.lastReadSeq >= c.inQ.maxCtnSeq {
		return false
	}
	upTo := c.inQ.maxCtnSeq
	c.lastReadSeq = upTo
	run, _ := c.inQ.deleteBefore(c.inQ.get(upTo))
	for node := run; node != nil; node = node.next {
		c.inQReady = append(c.inQReady, node.payload...)
		// the payload has been copied, so its buffer can return to the pool
		bpool.Put(node.buffer)
		node.payload, node.buffer = nil, nil
	}
	return true
}

// Read copies received in-order data into buf.
//
// It blocks until data is available. A timeout armed by SetReadDeadline bounds
// only the next wait and is consumed by it; expiry returns ErrIOTimeout.
// io.EOF is returned once evRead has been closed and no buffered data is left.
//
// should not call this function concurrently.
func (c *Conn) Read(buf []byte) (nr int, err error) {
	for {
		if len(c.inQReady) > 0 {
			nr = copy(buf, c.inQReady)
			c.inQReady = c.inQReady[nr:]
			return nr, nil
		}
		if c.readInQ() {
			// data moved into inQReady; deliver it on the next pass
			continue
		}
		var open bool
		if c.rtmo > 0 {
			tmo := c.rtmo
			c.rtmo = 0
			select {
			case _, open = <-c.evRead:
			case <-NewTimerChan(tmo):
				return 0, ErrIOTimeout
			}
		} else {
			_, open = <-c.evRead
		}
		// EOF only when evRead is closed and nothing is buffered
		if !open && len(c.inQReady) == 0 {
			return 0, io.EOF
		}
	}
}

// Write splits data into packets whose payload buffers are sized from c.mss.
// Each packet is queued through inputAndSend, which may block while the
// congestion window is full.
//
// nr counts only bytes whose packet inputAndSend accepted. When inputAndSend
// fails (io.EOF after close, or ErrIOTimeout), the failing chunk was never
// queued. It is therefore excluded from nr, and the error is returned with
// nr < len(data), as the io.Writer contract requires.
//
// should not call this function concurrently.
func (c *Conn) Write(data []byte) (nr int, err error) {
	for len(data) > 0 {
		buf := bpool.Get(c.mss + _AH_SIZE)
		body := buf[_TH_SIZE+_CH_SIZE:]
		n := copy(body, data)
		pk := &packet{flag: _F_DATA, payload: body[:n], buffer: buf[:_AH_SIZE+n]}
		if err = c.inputAndSend(pk); err != nil {
			// the chunk was not queued, so it does not count as written
			return nr, err
		}
		nr += n
		data = data[n:]
	}
	return nr, nil
}

// LocalAddr reports the local address of the underlying UDP socket.
func (c *Conn) LocalAddr() (addr net.Addr) {
	addr = c.sock.LocalAddr()
	return addr
}

// RemoteAddr returns the peer's UDP address. It returns nil once the
// destination has been cleared (internalWrite sets c.dest to nil after giving
// up on a FIN).
//
// The explicit nil check matters because c.dest is a *net.UDPAddr. Returning a
// nil pointer directly would yield a non-nil net.Addr interface, and a caller's
// `addr == nil` check would fail.
func (c *Conn) RemoteAddr() net.Addr {
	if c.dest == nil {
		return nil
	}
	return c.dest
}

// SetDeadline applies t as both the read and the write deadline.
func (c *Conn) SetDeadline(t time.Time) error {
	// the individual setters never fail, so their results are discarded
	_ = c.SetReadDeadline(t)
	_ = c.SetWriteDeadline(t)
	return nil
}

// SetReadDeadline arms a timeout (ms) for the next blocking wait in Read; that
// wait consumes it. A zero t clears any pending read timeout, as the net.Conn
// contract requires. Before this fix a zero t was ignored and a stale timeout
// survived. Deadlines that are not in the future are ignored.
//
// NOTE(review): net.Conn expects a past deadline to make reads fail
// immediately; that is not implemented here.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if t.IsZero() {
		c.rtmo = 0
		return nil
	}
	if d := t.UnixNano()/Millisecond - Now(); d > 0 {
		c.rtmo = d
	}
	return nil
}

// SetWriteDeadline arms a timeout (ms) for the next wait for window space in
// Write/inputAndSend; that wait consumes it. A zero t clears any pending write
// timeout, as the net.Conn contract requires. Before this fix a zero t was
// ignored and a stale timeout survived. Deadlines that are not in the future
// are ignored.
//
// NOTE(review): net.Conn expects a past deadline to make writes fail
// immediately; that is not implemented here.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	if t.IsZero() {
		c.wtmo = 0
		return nil
	}
	if d := t.UnixNano()/Millisecond - Now(); d > 0 {
		c.wtmo = d
	}
	return nil
}