chrislu
8 months ago
2 changed files with 205 additions and 0 deletions
-
120weed/mq/sub_coordinator/inflight_message_tracker.go
-
85weed/mq/sub_coordinator/inflight_message_tracker_test.go
@ -0,0 +1,120 @@ |
|||
package sub_coordinator |
|||
|
|||
import ( |
|||
"sort" |
|||
"sync" |
|||
) |
|||
|
|||
type InflightMessageTracker struct { |
|||
messages map[string]int64 |
|||
mu sync.Mutex |
|||
timestamps *RingBuffer |
|||
} |
|||
|
|||
func NewInflightMessageTracker(capacity int) *InflightMessageTracker { |
|||
return &InflightMessageTracker{ |
|||
messages: make(map[string]int64), |
|||
timestamps: NewRingBuffer(capacity), |
|||
} |
|||
} |
|||
|
|||
// InflightMessage tracks the message with the key and timestamp.
|
|||
// These messages are sent to the consumer group instances and waiting for ack.
|
|||
func (imt *InflightMessageTracker) InflightMessage(key []byte, tsNs int64) { |
|||
imt.mu.Lock() |
|||
defer imt.mu.Unlock() |
|||
imt.messages[string(key)] = tsNs |
|||
imt.timestamps.Add(tsNs) |
|||
} |
|||
// IsMessageAcknowledged returns true if the message has been acknowledged.
|
|||
// If the message is older than the oldest inflight messages, returns false.
|
|||
// returns false if the message is inflight.
|
|||
// Otherwise, returns false if the message is old and can be ignored.
|
|||
func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) bool { |
|||
imt.mu.Lock() |
|||
defer imt.mu.Unlock() |
|||
|
|||
if tsNs < imt.timestamps.Oldest() { |
|||
return true |
|||
} |
|||
if tsNs > imt.timestamps.Latest() { |
|||
return false |
|||
} |
|||
|
|||
if _, found := imt.messages[string(key)]; found { |
|||
return false |
|||
} |
|||
|
|||
return true |
|||
} |
|||
// AcknowledgeMessage acknowledges the message with the key and timestamp.
|
|||
func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool { |
|||
imt.mu.Lock() |
|||
defer imt.mu.Unlock() |
|||
timestamp, exists := imt.messages[string(key)] |
|||
if !exists || timestamp != tsNs { |
|||
return false |
|||
} |
|||
delete(imt.messages, string(key)) |
|||
// Remove the specific timestamp from the ring buffer.
|
|||
imt.timestamps.Remove(tsNs) |
|||
return true |
|||
} |
|||
|
|||
// RingBuffer represents a circular buffer to hold timestamps.
|
|||
type RingBuffer struct { |
|||
buffer []int64 |
|||
head int |
|||
size int |
|||
} |
|||
// NewRingBuffer creates a new RingBuffer of the given capacity.
|
|||
func NewRingBuffer(capacity int) *RingBuffer { |
|||
return &RingBuffer{ |
|||
buffer: make([]int64, capacity), |
|||
} |
|||
} |
|||
// Add adds a new timestamp to the ring buffer.
|
|||
func (rb *RingBuffer) Add(timestamp int64) { |
|||
rb.buffer[rb.head] = timestamp |
|||
rb.head = (rb.head + 1) % len(rb.buffer) |
|||
if rb.size < len(rb.buffer) { |
|||
rb.size++ |
|||
} |
|||
} |
|||
// Remove removes the specified timestamp from the ring buffer.
|
|||
func (rb *RingBuffer) Remove(timestamp int64) { |
|||
// Perform binary search
|
|||
index := sort.Search(rb.size, func(i int) bool { |
|||
return rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)] >= timestamp |
|||
}) |
|||
actualIndex := (rb.head + len(rb.buffer) - rb.size + index) % len(rb.buffer) |
|||
|
|||
if index < rb.size && rb.buffer[actualIndex] == timestamp { |
|||
// Shift elements to maintain the buffer order
|
|||
for i := index; i < rb.size-1; i++ { |
|||
fromIndex := (rb.head + len(rb.buffer) - rb.size + i + 1) % len(rb.buffer) |
|||
toIndex := (rb.head + len(rb.buffer) - rb.size + i) % len(rb.buffer) |
|||
rb.buffer[toIndex] = rb.buffer[fromIndex] |
|||
} |
|||
rb.size-- |
|||
rb.buffer[(rb.head+len(rb.buffer)-1)%len(rb.buffer)] = 0 // Clear the last element
|
|||
} |
|||
} |
|||
|
|||
// Oldest returns the oldest timestamp in the ring buffer.
|
|||
func (rb *RingBuffer) Oldest() int64 { |
|||
if rb.size == 0 { |
|||
return 0 |
|||
} |
|||
oldestIndex := (rb.head + len(rb.buffer) - rb.size) % len(rb.buffer) |
|||
return rb.buffer[oldestIndex] |
|||
} |
|||
|
|||
// Latest returns the most recently added timestamp in the ring buffer.
|
|||
func (rb *RingBuffer) Latest() int64 { |
|||
if rb.size == 0 { |
|||
return 0 |
|||
} |
|||
latestIndex := (rb.head + len(rb.buffer) - 1) % len(rb.buffer) |
|||
return rb.buffer[latestIndex] |
|||
} |
@ -0,0 +1,85 @@ |
|||
package sub_coordinator |
|||
|
|||
import ( |
|||
"sort" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestRingBuffer(t *testing.T) { |
|||
// Initialize a RingBuffer with capacity 5
|
|||
rb := NewRingBuffer(5) |
|||
|
|||
// Add timestamps to the buffer
|
|||
timestamps := []int64{100, 200, 300, 400, 500} |
|||
for _, ts := range timestamps { |
|||
rb.Add(ts) |
|||
} |
|||
|
|||
// Test Add method and buffer size
|
|||
expectedSize := 5 |
|||
if rb.size != expectedSize { |
|||
t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size) |
|||
} |
|||
|
|||
// Test Oldest and Latest methods
|
|||
expectedOldest := int64(100) |
|||
if oldest := rb.Oldest(); oldest != expectedOldest { |
|||
t.Errorf("Expected oldest timestamp %d, got %d", expectedOldest, oldest) |
|||
} |
|||
expectedLatest := int64(500) |
|||
if latest := rb.Latest(); latest != expectedLatest { |
|||
t.Errorf("Expected latest timestamp %d, got %d", expectedLatest, latest) |
|||
} |
|||
|
|||
// Test Remove method
|
|||
rb.Remove(200) |
|||
expectedSize-- |
|||
if rb.size != expectedSize { |
|||
t.Errorf("Expected buffer size %d after removal, got %d", expectedSize, rb.size) |
|||
} |
|||
|
|||
// Test removal of non-existent element
|
|||
rb.Remove(600) |
|||
if rb.size != expectedSize { |
|||
t.Errorf("Expected buffer size %d after attempting removal of non-existent element, got %d", expectedSize, rb.size) |
|||
} |
|||
|
|||
// Test binary search correctness
|
|||
target := int64(300) |
|||
index := sort.Search(rb.size, func(i int) bool { |
|||
return rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)] >= target |
|||
}) |
|||
actualIndex := (rb.head + len(rb.buffer) - rb.size + index) % len(rb.buffer) |
|||
if rb.buffer[actualIndex] != target { |
|||
t.Errorf("Binary search failed to find the correct index for timestamp %d", target) |
|||
} |
|||
} |
|||
|
|||
func TestInflightMessageTracker(t *testing.T) { |
|||
// Initialize an InflightMessageTracker with capacity 5
|
|||
tracker := NewInflightMessageTracker(5) |
|||
|
|||
// Add inflight messages
|
|||
key := []byte("exampleKey") |
|||
timestamp := time.Now().UnixNano() |
|||
tracker.InflightMessage(key, timestamp) |
|||
|
|||
// Test IsMessageAcknowledged method
|
|||
isOld := tracker.IsMessageAcknowledged(key, timestamp-10) |
|||
if !isOld { |
|||
t.Error("Expected message to be old") |
|||
} |
|||
|
|||
// Test AcknowledgeMessage method
|
|||
acked := tracker.AcknowledgeMessage(key, timestamp) |
|||
if !acked { |
|||
t.Error("Expected message to be acked") |
|||
} |
|||
if _, exists := tracker.messages[string(key)]; exists { |
|||
t.Error("Expected message to be deleted after ack") |
|||
} |
|||
if tracker.timestamps.size != 0 { |
|||
t.Error("Expected buffer size to be 0 after ack") |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue