diff --git a/weed/mq/broker/broker_grpc_sub.go b/weed/mq/broker/broker_grpc_sub.go index 62a329499..126322cd9 100644 --- a/weed/mq/broker/broker_grpc_sub.go +++ b/weed/mq/broker/broker_grpc_sub.go @@ -101,11 +101,15 @@ func (b *MessageQueueBroker) SubscribeMessage(stream mq_pb.SeaweedMessaging_Subs }}) break } - glog.V(0).Infof("topic %v partition %v subscriber %s error: %v", t, partition, clientName, err) + glog.V(0).Infof("topic %v partition %v subscriber %s lastOffset %d error: %v", t, partition, clientName, lastOffset, err) break } + if ack.GetAck().Key == nil { + // skip ack for control messages + continue + } imt.AcknowledgeMessage(ack.GetAck().Key, ack.GetAck().Sequence) - currentLastOffset := imt.GetOldest() + currentLastOffset := imt.GetOldestAckedTimestamp() fmt.Printf("%+v recv (%s,%d), oldest %d\n", partition, string(ack.GetAck().Key), ack.GetAck().Sequence, currentLastOffset) if subscribeFollowMeStream != nil && currentLastOffset > lastOffset { if err := subscribeFollowMeStream.Send(&mq_pb.SubscribeFollowMeRequest{ @@ -124,16 +128,16 @@ func (b *MessageQueueBroker) SubscribeMessage(stream mq_pb.SeaweedMessaging_Subs } if lastOffset > 0 { if err := b.saveConsumerGroupOffset(t, partition, req.GetInit().ConsumerGroup, lastOffset); err != nil { - glog.Errorf("saveConsumerGroupOffset: %v", err) + glog.Errorf("saveConsumerGroupOffset partition %v lastOffset %d: %v", partition, lastOffset, err) } - if subscribeFollowMeStream != nil { - if err := subscribeFollowMeStream.Send(&mq_pb.SubscribeFollowMeRequest{ - Message: &mq_pb.SubscribeFollowMeRequest_Close{ - Close: &mq_pb.SubscribeFollowMeRequest_CloseMessage{}, - }, - }); err != nil { - glog.Errorf("Error sending close to follower: %v", err) - } + } + if subscribeFollowMeStream != nil { + if err := subscribeFollowMeStream.Send(&mq_pb.SubscribeFollowMeRequest{ + Message: &mq_pb.SubscribeFollowMeRequest_Close{ + Close: &mq_pb.SubscribeFollowMeRequest_CloseMessage{}, + }, + }); err != nil { + glog.Errorf("Error sending close to follower: %v", err) } } }() @@ -170,8 +174,9 @@ func (b *MessageQueueBroker) SubscribeMessage(stream mq_pb.SeaweedMessaging_Subs for imt.IsInflight(logEntry.Key) { time.Sleep(137 * time.Millisecond) } - - imt.InflightMessage(logEntry.Key, logEntry.TsNs) + if logEntry.Key != nil { + imt.EnflightMessage(logEntry.Key, logEntry.TsNs) + } if err := stream.Send(&mq_pb.SubscribeMessageResponse{Message: &mq_pb.SubscribeMessageResponse_Data{ Data: &mq_pb.DataMessage{ diff --git a/weed/mq/sub_coordinator/inflight_message_tracker.go b/weed/mq/sub_coordinator/inflight_message_tracker.go index 1b1caca00..6ac5103a9 100644 --- a/weed/mq/sub_coordinator/inflight_message_tracker.go +++ b/weed/mq/sub_coordinator/inflight_message_tracker.go @@ -1,6 +1,7 @@ package sub_coordinator import ( + "fmt" "sort" "sync" ) @@ -18,13 +19,14 @@ func NewInflightMessageTracker(capacity int) *InflightMessageTracker { } } -// InflightMessage tracks the message with the key and timestamp. +// EnflightMessage tracks the message with the key and timestamp. // These messages are sent to the consumer group instances and waiting for ack. -func (imt *InflightMessageTracker) InflightMessage(key []byte, tsNs int64) { +func (imt *InflightMessageTracker) EnflightMessage(key []byte, tsNs int64) { + fmt.Printf("EnflightMessage(%s,%d)\n", string(key), tsNs) imt.mu.Lock() defer imt.mu.Unlock() imt.messages[string(key)] = tsNs - imt.timestamps.Add(tsNs) + imt.timestamps.EnflightTimestamp(tsNs) } // IsMessageAcknowledged returns true if the message has been acknowledged. @@ -35,7 +37,7 @@ func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) imt.mu.Lock() defer imt.mu.Unlock() - if tsNs < imt.timestamps.Oldest() { + if tsNs <= imt.timestamps.OldestAckedTimestamp() { return true } if tsNs > imt.timestamps.Latest() { @@ -51,6 +53,7 @@ func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) // AcknowledgeMessage acknowledges the message with the key and timestamp. func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool { + fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs) imt.mu.Lock() defer imt.mu.Unlock() timestamp, exists := imt.messages[string(key)] @@ -59,12 +62,12 @@ func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bo } delete(imt.messages, string(key)) // Remove the specific timestamp from the ring buffer. - imt.timestamps.Remove(tsNs) + imt.timestamps.AckTimestamp(tsNs) return true } -func (imt *InflightMessageTracker) GetOldest() int64 { - return imt.timestamps.Oldest() +func (imt *InflightMessageTracker) GetOldestAckedTimestamp() int64 { + return imt.timestamps.OldestAckedTimestamp() } // IsInflight returns true if the message with the key is inflight. @@ -75,63 +78,81 @@ func (imt *InflightMessageTracker) IsInflight(key []byte) bool { return found } +type TimestampStatus struct { + Timestamp int64 + Acked bool +} + // RingBuffer represents a circular buffer to hold timestamps. type RingBuffer struct { - buffer []int64 - head int - size int + buffer []*TimestampStatus + head int + size int + maxTimestamp int64 + minAckedTs int64 } // NewRingBuffer creates a new RingBuffer of the given capacity. func NewRingBuffer(capacity int) *RingBuffer { return &RingBuffer{ - buffer: make([]int64, capacity), + buffer: newBuffer(capacity), } } -// Add adds a new timestamp to the ring buffer. -func (rb *RingBuffer) Add(timestamp int64) { - rb.buffer[rb.head] = timestamp - rb.head = (rb.head + 1) % len(rb.buffer) +func newBuffer(capacity int) []*TimestampStatus { + buffer := make([]*TimestampStatus, capacity) + for i := range buffer { + buffer[i] = &TimestampStatus{} + } + return buffer +} + +// EnflightTimestamp adds a new timestamp to the ring buffer. +func (rb *RingBuffer) EnflightTimestamp(timestamp int64) { if rb.size < len(rb.buffer) { rb.size++ + } else { + newBuf := newBuffer(2*len(rb.buffer)) + for i := 0; i < rb.size; i++ { + newBuf[i] = rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)] + } + rb.buffer = newBuf + rb.head = rb.size + rb.size++ + } + head := rb.buffer[rb.head] + head.Timestamp = timestamp + head.Acked = false + rb.head = (rb.head + 1) % len(rb.buffer) + if timestamp > rb.maxTimestamp { + rb.maxTimestamp = timestamp } } -// Remove removes the specified timestamp from the ring buffer. -func (rb *RingBuffer) Remove(timestamp int64) { +// AckTimestamp removes the specified timestamp from the ring buffer. +func (rb *RingBuffer) AckTimestamp(timestamp int64) { // Perform binary search index := sort.Search(rb.size, func(i int) bool { - return rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)] >= timestamp + return rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)].Timestamp >= timestamp }) actualIndex := (rb.head + len(rb.buffer) - rb.size + index) % len(rb.buffer) - if index < rb.size && rb.buffer[actualIndex] == timestamp { - // Shift elements to maintain the buffer order - for i := index; i < rb.size-1; i++ { - fromIndex := (rb.head + len(rb.buffer) - rb.size + i + 1) % len(rb.buffer) - toIndex := (rb.head + len(rb.buffer) - rb.size + i) % len(rb.buffer) - rb.buffer[toIndex] = rb.buffer[fromIndex] - } + rb.buffer[actualIndex].Acked = true + + // Remove all the acknowledged timestamps from the buffer + startPos := (rb.head + len(rb.buffer) - rb.size) % len(rb.buffer) + for i := 0; i < len(rb.buffer) && rb.buffer[(startPos+i)%len(rb.buffer)].Acked; i++ { rb.size-- - rb.buffer[(rb.head+len(rb.buffer)-1)%len(rb.buffer)] = 0 // Clear the last element + rb.minAckedTs = rb.buffer[(startPos+i)%len(rb.buffer)].Timestamp } } -// Oldest returns the oldest timestamp in the ring buffer. -func (rb *RingBuffer) Oldest() int64 { - if rb.size == 0 { - return 0 - } - oldestIndex := (rb.head + len(rb.buffer) - rb.size) % len(rb.buffer) - return rb.buffer[oldestIndex] +// OldestAckedTimestamp returns the oldest that is already acked timestamp in the ring buffer. +func (rb *RingBuffer) OldestAckedTimestamp() int64 { + return rb.minAckedTs } -// Latest returns the most recently added timestamp in the ring buffer. +// Latest returns the most recently known timestamp in the ring buffer. func (rb *RingBuffer) Latest() int64 { - if rb.size == 0 { - return 0 - } - latestIndex := (rb.head + len(rb.buffer) - 1) % len(rb.buffer) - return rb.buffer[latestIndex] + return rb.maxTimestamp } diff --git a/weed/mq/sub_coordinator/inflight_message_tracker_test.go b/weed/mq/sub_coordinator/inflight_message_tracker_test.go index 4b35e32bf..83c33b5ba 100644 --- a/weed/mq/sub_coordinator/inflight_message_tracker_test.go +++ b/weed/mq/sub_coordinator/inflight_message_tracker_test.go @@ -1,9 +1,8 @@ package sub_coordinator import ( - "sort" + "github.com/stretchr/testify/assert" "testing" - "time" ) func TestRingBuffer(t *testing.T) { @@ -13,7 +12,7 @@ func TestRingBuffer(t *testing.T) { // Add timestamps to the buffer timestamps := []int64{100, 200, 300, 400, 500} for _, ts := range timestamps { - rb.Add(ts) + rb.EnflightTimestamp(ts) } // Test Add method and buffer size @@ -22,38 +21,25 @@ func TestRingBuffer(t *testing.T) { t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size) } - // Test Oldest and Latest methods - expectedOldest := int64(100) - if oldest := rb.Oldest(); oldest != expectedOldest { - t.Errorf("Expected oldest timestamp %d, got %d", expectedOldest, oldest) - } - expectedLatest := int64(500) - if latest := rb.Latest(); latest != expectedLatest { - t.Errorf("Expected latest timestamp %d, got %d", expectedLatest, latest) - } + assert.Equal(t, int64(0), rb.OldestAckedTimestamp()) + assert.Equal(t, int64(500), rb.Latest()) - // Test Remove method - rb.Remove(200) - expectedSize-- - if rb.size != expectedSize { - t.Errorf("Expected buffer size %d after removal, got %d", expectedSize, rb.size) - } + rb.AckTimestamp(200) + assert.Equal(t, int64(0), rb.OldestAckedTimestamp()) + rb.AckTimestamp(100) + assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) - // Test removal of non-existent element - rb.Remove(600) - if rb.size != expectedSize { - t.Errorf("Expected buffer size %d after attempting removal of non-existent element, got %d", expectedSize, rb.size) - } + rb.EnflightTimestamp(int64(600)) + rb.EnflightTimestamp(int64(700)) - // Test binary search correctness - target := int64(300) - index := sort.Search(rb.size, func(i int) bool { - return rb.buffer[(rb.head+len(rb.buffer)-rb.size+i)%len(rb.buffer)] >= target - }) - actualIndex := (rb.head + len(rb.buffer) - rb.size + index) % len(rb.buffer) - if rb.buffer[actualIndex] != target { - t.Errorf("Binary search failed to find the correct index for timestamp %d", target) - } + rb.AckTimestamp(500) + assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) + rb.AckTimestamp(400) + assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) + rb.AckTimestamp(300) + assert.Equal(t, int64(500), rb.OldestAckedTimestamp()) + + assert.Equal(t, int64(700), rb.Latest()) } func TestInflightMessageTracker(t *testing.T) { @@ -61,9 +47,9 @@ func TestInflightMessageTracker(t *testing.T) { tracker := NewInflightMessageTracker(5) // Add inflight messages - key := []byte("exampleKey") - timestamp := time.Now().UnixNano() - tracker.InflightMessage(key, timestamp) + key := []byte("1") + timestamp := int64(1) + tracker.EnflightMessage(key, timestamp) // Test IsMessageAcknowledged method isOld := tracker.IsMessageAcknowledged(key, timestamp-10) @@ -82,4 +68,29 @@ func TestInflightMessageTracker(t *testing.T) { if tracker.timestamps.size != 0 { t.Error("Expected buffer size to be 0 after ack") } + assert.Equal(t, timestamp, tracker.GetOldestAckedTimestamp()) +} + +func TestInflightMessageTracker2(t *testing.T) { + // Initialize an InflightMessageTracker with initial capacity 1 + tracker := NewInflightMessageTracker(1) + + tracker.EnflightMessage([]byte("1"), int64(1)) + tracker.EnflightMessage([]byte("2"), int64(2)) + tracker.EnflightMessage([]byte("3"), int64(3)) + tracker.EnflightMessage([]byte("4"), int64(4)) + tracker.EnflightMessage([]byte("5"), int64(5)) + assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1))) + assert.Equal(t, int64(1), tracker.GetOldestAckedTimestamp()) + + // Test IsMessageAcknowledged method + isAcked := tracker.IsMessageAcknowledged([]byte("2"), int64(2)) + if isAcked { + t.Error("Expected message to be not acked") + } + + // Test AcknowledgeMessage method + assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2))) + assert.Equal(t, int64(2), tracker.GetOldestAckedTimestamp()) + }