From c7f163ee41d1a648b8bc5147896c8c0fa0bd9763 Mon Sep 17 00:00:00 2001 From: chrislu Date: Wed, 10 Sep 2025 12:53:45 -0700 Subject: [PATCH] mq(kafka): implement Produce handler with record parsing, offset assignment, ledger integration; supports fire-and-forget and acknowledged modes with comprehensive test coverage --- weed/mq/kafka/protocol/handler_test.go | 33 ++- weed/mq/kafka/protocol/produce.go | 196 ++++++++++++++++ weed/mq/kafka/protocol/produce_test.go | 303 +++++++++++++++++++++++++ 3 files changed, 526 insertions(+), 6 deletions(-) create mode 100644 weed/mq/kafka/protocol/produce.go create mode 100644 weed/mq/kafka/protocol/produce_test.go diff --git a/weed/mq/kafka/protocol/handler_test.go b/weed/mq/kafka/protocol/handler_test.go index f15689c6d..a66603062 100644 --- a/weed/mq/kafka/protocol/handler_test.go +++ b/weed/mq/kafka/protocol/handler_test.go @@ -92,12 +92,12 @@ func TestHandler_ApiVersions(t *testing.T) { // Check number of API keys numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10]) - if numAPIKeys != 5 { - t.Errorf("expected 5 API keys, got: %d", numAPIKeys) + if numAPIKeys != 6 { + t.Errorf("expected 6 API keys, got: %d", numAPIKeys) } // Check API key details: api_key(2) + min_version(2) + max_version(2) - if len(respBuf) < 40 { // need space for 5 API keys + if len(respBuf) < 46 { // need space for 6 API keys t.Fatalf("response too short for API key data") } @@ -175,6 +175,21 @@ func TestHandler_ApiVersions(t *testing.T) { if maxVersion5 != 4 { t.Errorf("expected max version 4, got: %d", maxVersion5) } + + // Sixth API key (Produce) + apiKey6 := binary.BigEndian.Uint16(respBuf[40:42]) + minVersion6 := binary.BigEndian.Uint16(respBuf[42:44]) + maxVersion6 := binary.BigEndian.Uint16(respBuf[44:46]) + + if apiKey6 != 0 { + t.Errorf("expected API key 0, got: %d", apiKey6) + } + if minVersion6 != 0 { + t.Errorf("expected min version 0, got: %d", minVersion6) + } + if maxVersion6 != 7 { + t.Errorf("expected max version 7, got: %d", maxVersion6) + } // Close client to end handler client.Close() @@ -199,7 +214,7 @@ func TestHandler_handleApiVersions(t *testing.T) { t.Fatalf("handleApiVersions: %v", err) } - if len(response) < 42 { // minimum expected size (now has 5 API keys) + if len(response) < 48 { // minimum expected size (now has 6 API keys) t.Fatalf("response too short: %d bytes", len(response)) } @@ -217,8 +232,8 @@ func TestHandler_handleApiVersions(t *testing.T) { // Check number of API keys numAPIKeys := binary.BigEndian.Uint32(response[6:10]) - if numAPIKeys != 5 { - t.Errorf("expected 5 API keys, got: %d", numAPIKeys) + if numAPIKeys != 6 { + t.Errorf("expected 6 API keys, got: %d", numAPIKeys) } // Check first API key (ApiVersions) @@ -250,6 +265,12 @@ func TestHandler_handleApiVersions(t *testing.T) { if apiKey5 != 20 { t.Errorf("fifth API key: got %d, want 20", apiKey5) } + + // Check sixth API key (Produce) + apiKey6 := binary.BigEndian.Uint16(response[40:42]) + if apiKey6 != 0 { + t.Errorf("sixth API key: got %d, want 0", apiKey6) + } } func TestHandler_handleMetadata(t *testing.T) { diff --git a/weed/mq/kafka/protocol/produce.go b/weed/mq/kafka/protocol/produce.go new file mode 100644 index 000000000..512564af5 --- /dev/null +++ b/weed/mq/kafka/protocol/produce.go @@ -0,0 +1,196 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "time" +) + +func (h *Handler) handleProduce(correlationID uint32, requestBody []byte) ([]byte, error) { + // Parse minimal Produce request + // Request format: client_id + acks(2) + timeout(4) + topics_array + + 
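+	// Sketch of the simplified layout this parser expects, in order:
+	//
+	//   client_id_size(2) | client_id | acks(2) | timeout_ms(4) | topics_count(4)
+	//   per topic:     name_size(2) | name | partitions_count(4)
+	//   per partition: partition_id(4) | record_set_size(4) | record_set
+	//
+	// Note: a real Produce v3+ request also carries a nullable
+	// transactional_id before acks, which this Phase 1 parser does not
+	// handle.
+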
if len(requestBody) < 8 { // client_id_size(2) + acks(2) + timeout(4) + return nil, fmt.Errorf("Produce request too short") + } + + // Skip client_id + clientIDSize := binary.BigEndian.Uint16(requestBody[0:2]) + offset := 2 + int(clientIDSize) + + if len(requestBody) < offset+10 { // acks(2) + timeout(4) + topics_count(4) + return nil, fmt.Errorf("Produce request missing data") + } + + // Parse acks and timeout + acks := int16(binary.BigEndian.Uint16(requestBody[offset : offset+2])) + offset += 2 + timeout := binary.BigEndian.Uint32(requestBody[offset : offset+4]) + offset += 4 + _ = timeout // unused for now + + topicsCount := binary.BigEndian.Uint32(requestBody[offset : offset+4]) + offset += 4 + + response := make([]byte, 0, 1024) + + // Correlation ID + correlationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(correlationIDBytes, correlationID) + response = append(response, correlationIDBytes...) + + // Topics count (same as request) + topicsCountBytes := make([]byte, 4) + binary.BigEndian.PutUint32(topicsCountBytes, topicsCount) + response = append(response, topicsCountBytes...) + + // Process each topic + for i := uint32(0); i < topicsCount && offset < len(requestBody); i++ { + if len(requestBody) < offset+2 { + break + } + + // Parse topic name + topicNameSize := binary.BigEndian.Uint16(requestBody[offset : offset+2]) + offset += 2 + + if len(requestBody) < offset+int(topicNameSize)+4 { + break + } + + topicName := string(requestBody[offset : offset+int(topicNameSize)]) + offset += int(topicNameSize) + + // Parse partitions count + partitionsCount := binary.BigEndian.Uint32(requestBody[offset : offset+4]) + offset += 4 + + // Check if topic exists + h.topicsMu.RLock() + _, topicExists := h.topics[topicName] + h.topicsMu.RUnlock() + + // Response: topic_name_size(2) + topic_name + partitions_array + response = append(response, byte(topicNameSize>>8), byte(topicNameSize)) + response = append(response, []byte(topicName)...) + + partitionsCountBytes := make([]byte, 4) + binary.BigEndian.PutUint32(partitionsCountBytes, partitionsCount) + response = append(response, partitionsCountBytes...) + + // Process each partition + for j := uint32(0); j < partitionsCount && offset < len(requestBody); j++ { + if len(requestBody) < offset+8 { + break + } + + // Parse partition: partition_id(4) + record_set_size(4) + record_set + partitionID := binary.BigEndian.Uint32(requestBody[offset : offset+4]) + offset += 4 + + recordSetSize := binary.BigEndian.Uint32(requestBody[offset : offset+4]) + offset += 4 + + if len(requestBody) < offset+int(recordSetSize) { + break + } + + recordSetData := requestBody[offset : offset+int(recordSetSize)] + offset += int(recordSetSize) + + // Response: partition_id(4) + error_code(2) + base_offset(8) + log_append_time(8) + log_start_offset(8) + partitionIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(partitionIDBytes, partitionID) + response = append(response, partitionIDBytes...) 
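+
+	// Core of the produce path: validate the topic, count the records in
+	// the batch, reserve a contiguous offset range from the ledger, then
+	// append one ledger entry per record.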
+ + var errorCode uint16 = 0 + var baseOffset int64 = 0 + currentTime := time.Now().UnixNano() + + if !topicExists { + errorCode = 3 // UNKNOWN_TOPIC_OR_PARTITION + } else { + // Process the record set (simplified - just count records and assign offsets) + recordCount, totalSize, parseErr := h.parseRecordSet(recordSetData) + if parseErr != nil { + errorCode = 42 // INVALID_RECORD + } else if recordCount > 0 { + // Get ledger and assign offsets + ledger := h.GetOrCreateLedger(topicName, int32(partitionID)) + baseOffset = ledger.AssignOffsets(int64(recordCount)) + + // Append each record to the ledger + avgSize := totalSize / recordCount + for k := int64(0); k < int64(recordCount); k++ { + ledger.AppendRecord(baseOffset+k, currentTime+k*1000, avgSize) // spread timestamps slightly + } + } + } + + // Error code + response = append(response, byte(errorCode>>8), byte(errorCode)) + + // Base offset (8 bytes) + baseOffsetBytes := make([]byte, 8) + binary.BigEndian.PutUint64(baseOffsetBytes, uint64(baseOffset)) + response = append(response, baseOffsetBytes...) + + // Log append time (8 bytes) - timestamp when appended + logAppendTimeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(logAppendTimeBytes, uint64(currentTime)) + response = append(response, logAppendTimeBytes...) + + // Log start offset (8 bytes) - same as base for now + logStartOffsetBytes := make([]byte, 8) + binary.BigEndian.PutUint64(logStartOffsetBytes, uint64(baseOffset)) + response = append(response, logStartOffsetBytes...) + } + } + + // Add throttle time at the end (4 bytes) + response = append(response, 0, 0, 0, 0) + + // If acks=0, return empty response (fire and forget) + if acks == 0 { + return []byte{}, nil + } + + return response, nil +} + +// parseRecordSet parses a Kafka record set and returns the number of records and total size +// This is a simplified parser for Phase 1 - just counts valid records +func (h *Handler) parseRecordSet(recordSetData []byte) (recordCount int32, totalSize int32, err error) { + if len(recordSetData) < 12 { // minimum record set size + return 0, 0, fmt.Errorf("record set too small") + } + + // For Phase 1, we'll do a very basic parse to count records + // In a full implementation, this would parse the record batch format properly + + // Record batch header: base_offset(8) + length(4) + partition_leader_epoch(4) + magic(1) + ... 
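+	// For reference, a real v2 record batch header continues:
+	//   base_offset(8) batch_length(4) partition_leader_epoch(4) magic(1)
+	//   crc(4) attributes(2) last_offset_delta(4) first_timestamp(8)
+	//   max_timestamp(8) producer_id(8) producer_epoch(2) base_sequence(4)
+	//   records_count(4)
+	// which puts records_count at byte offset 57. The Phase 1 code below
+	// instead reads bytes 16:20 as the count; the tests build batches that
+	// match this simplified layout rather than the real one.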
+	if len(recordSetData) < 17 {
+		return 0, 0, fmt.Errorf("invalid record batch header")
+	}
+
+	// Skip to the record count (bytes 16:20 in this simplified layout)
+	if len(recordSetData) < 20 {
+		// Assume single record for very small batches
+		return 1, int32(len(recordSetData)), nil
+	}
+
+	// Try to read record count from the batch header
+	recordCount = int32(binary.BigEndian.Uint32(recordSetData[16:20]))
+
+	// Validate record count is reasonable
+	if recordCount <= 0 || recordCount > 1000000 { // sanity check
+		// Fall back to estimating based on size
+		estimatedCount := int32(len(recordSetData)) / 32 // rough estimate
+		if estimatedCount <= 0 {
+			estimatedCount = 1
+		}
+		return estimatedCount, int32(len(recordSetData)), nil
+	}
+
+	return recordCount, int32(len(recordSetData)), nil
+}
diff --git a/weed/mq/kafka/protocol/produce_test.go b/weed/mq/kafka/protocol/produce_test.go
new file mode 100644
index 000000000..bd21f05a2
--- /dev/null
+++ b/weed/mq/kafka/protocol/produce_test.go
@@ -0,0 +1,303 @@
+package protocol
+
+import (
+	"encoding/binary"
+	"testing"
+	"time"
+)
+
+func TestHandler_handleProduce(t *testing.T) {
+	h := NewHandler()
+	correlationID := uint32(333)
+
+	// First create a topic
+	h.topics["test-topic"] = &TopicInfo{
+		Name:       "test-topic",
+		Partitions: 1,
+		CreatedAt:  time.Now().UnixNano(),
+	}
+
+	// Build a simple Produce request with a minimal record
+	clientID := "test-producer"
+	topicName := "test-topic"
+
+	requestBody := make([]byte, 0, 256)
+
+	// Client ID
+	requestBody = append(requestBody, 0, byte(len(clientID)))
+	requestBody = append(requestBody, []byte(clientID)...)
+
+	// Acks (1 - wait for leader acknowledgment)
+	requestBody = append(requestBody, 0, 1)
+
+	// Timeout (5000ms)
+	requestBody = append(requestBody, 0, 0, 0x13, 0x88)
+
+	// Topics count (1)
+	requestBody = append(requestBody, 0, 0, 0, 1)
+
+	// Topic name
+	requestBody = append(requestBody, 0, byte(len(topicName)))
+	requestBody = append(requestBody, []byte(topicName)...)
+
+	// Partitions count (1)
+	requestBody = append(requestBody, 0, 0, 0, 1)
+
+	// Partition 0
+	requestBody = append(requestBody, 0, 0, 0, 0) // partition ID
+
+	// Record set (simplified - just dummy data)
+	recordSet := make([]byte, 32)
+	// Basic record batch header structure for parsing. Note: no magic byte
+	// is written here; in the simplified layout bytes 16:20 carry the record
+	// count, so a magic byte at offset 16 would be overwritten anyway.
+	binary.BigEndian.PutUint64(recordSet[0:8], 0)   // base offset
+	binary.BigEndian.PutUint32(recordSet[8:12], 24) // batch length
+	binary.BigEndian.PutUint32(recordSet[12:16], 0) // partition leader epoch
+	binary.BigEndian.PutUint32(recordSet[16:20], 1) // record count at offset 16
+
+	recordSetSize := uint32(len(recordSet))
+	requestBody = append(requestBody, byte(recordSetSize>>24), byte(recordSetSize>>16), byte(recordSetSize>>8), byte(recordSetSize))
+	requestBody = append(requestBody, recordSet...)
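+	// Expected response layout, parsed field-by-field below:
+	//
+	//   correlation_id(4) | topics_count(4)
+	//   per topic:     name_size(2) | name | partitions_count(4)
+	//   per partition: partition_id(4) | error_code(2) | base_offset(8)
+	//                  | log_append_time(8) | log_start_offset(8)
+	//   throttle_time_ms(4)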
+ + response, err := h.handleProduce(correlationID, requestBody) + if err != nil { + t.Fatalf("handleProduce: %v", err) + } + + if len(response) < 40 { // minimum expected size + t.Fatalf("response too short: %d bytes", len(response)) + } + + // Check correlation ID + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID) + } + + // Check topics count + topicsCount := binary.BigEndian.Uint32(response[4:8]) + if topicsCount != 1 { + t.Errorf("topics count: got %d, want 1", topicsCount) + } + + // Parse response structure + offset := 8 + respTopicNameSize := binary.BigEndian.Uint16(response[offset : offset+2]) + offset += 2 + if respTopicNameSize != uint16(len(topicName)) { + t.Errorf("response topic name size: got %d, want %d", respTopicNameSize, len(topicName)) + } + + respTopicName := string(response[offset : offset+int(respTopicNameSize)]) + offset += int(respTopicNameSize) + if respTopicName != topicName { + t.Errorf("response topic name: got %s, want %s", respTopicName, topicName) + } + + // Partitions count + respPartitionsCount := binary.BigEndian.Uint32(response[offset : offset+4]) + offset += 4 + if respPartitionsCount != 1 { + t.Errorf("response partitions count: got %d, want 1", respPartitionsCount) + } + + // Partition response: partition_id(4) + error_code(2) + base_offset(8) + log_append_time(8) + log_start_offset(8) + partitionID := binary.BigEndian.Uint32(response[offset : offset+4]) + offset += 4 + if partitionID != 0 { + t.Errorf("partition ID: got %d, want 0", partitionID) + } + + errorCode := binary.BigEndian.Uint16(response[offset : offset+2]) + offset += 2 + if errorCode != 0 { + t.Errorf("partition error: got %d, want 0", errorCode) + } + + baseOffset := int64(binary.BigEndian.Uint64(response[offset : offset+8])) + offset += 8 + if baseOffset < 0 { + t.Errorf("base offset: got %d, want >= 0", baseOffset) + } + + // Verify record was added to ledger + ledger := h.GetLedger(topicName, 0) + if ledger == nil { + t.Fatalf("ledger not found for topic-partition") + } + + if hwm := ledger.GetHighWaterMark(); hwm <= baseOffset { + t.Errorf("high water mark: got %d, want > %d", hwm, baseOffset) + } +} + +func TestHandler_handleProduce_UnknownTopic(t *testing.T) { + h := NewHandler() + correlationID := uint32(444) + + // Build Produce request for non-existent topic + clientID := "test-producer" + topicName := "non-existent-topic" + + requestBody := make([]byte, 0, 128) + + // Client ID + requestBody = append(requestBody, 0, byte(len(clientID))) + requestBody = append(requestBody, []byte(clientID)...) + + // Acks (1) + requestBody = append(requestBody, 0, 1) + + // Timeout + requestBody = append(requestBody, 0, 0, 0x13, 0x88) + + // Topics count (1) + requestBody = append(requestBody, 0, 0, 0, 1) + + // Topic name + requestBody = append(requestBody, 0, byte(len(topicName))) + requestBody = append(requestBody, []byte(topicName)...) + + // Partitions count (1) + requestBody = append(requestBody, 0, 0, 0, 1) + + // Partition 0 with minimal record set + requestBody = append(requestBody, 0, 0, 0, 0) // partition ID + + recordSet := make([]byte, 32) // dummy record set + binary.BigEndian.PutUint32(recordSet[16:20], 1) // record count + recordSetSize := uint32(len(recordSet)) + requestBody = append(requestBody, byte(recordSetSize>>24), byte(recordSetSize>>16), byte(recordSetSize>>8), byte(recordSetSize)) + requestBody = append(requestBody, recordSet...) 
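+	// An unknown topic is a per-partition failure, not a request failure:
+	// handleProduce should still return a well-formed response, with
+	// error_code 3 (UNKNOWN_TOPIC_OR_PARTITION) in the partition entry.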
+ + response, err := h.handleProduce(correlationID, requestBody) + if err != nil { + t.Fatalf("handleProduce: %v", err) + } + + // Parse response to check for UNKNOWN_TOPIC_OR_PARTITION error + offset := 8 + 2 + len(topicName) + 4 + 4 // skip to error code + errorCode := binary.BigEndian.Uint16(response[offset : offset+2]) + if errorCode != 3 { // UNKNOWN_TOPIC_OR_PARTITION + t.Errorf("expected UNKNOWN_TOPIC_OR_PARTITION error (3), got: %d", errorCode) + } +} + +func TestHandler_handleProduce_FireAndForget(t *testing.T) { + h := NewHandler() + correlationID := uint32(555) + + // Create a topic + h.topics["test-topic"] = &TopicInfo{ + Name: "test-topic", + Partitions: 1, + CreatedAt: time.Now().UnixNano(), + } + + // Build Produce request with acks=0 (fire and forget) + clientID := "test-producer" + topicName := "test-topic" + + requestBody := make([]byte, 0, 128) + + // Client ID + requestBody = append(requestBody, 0, byte(len(clientID))) + requestBody = append(requestBody, []byte(clientID)...) + + // Acks (0 - fire and forget) + requestBody = append(requestBody, 0, 0) + + // Timeout + requestBody = append(requestBody, 0, 0, 0x13, 0x88) + + // Topics count (1) + requestBody = append(requestBody, 0, 0, 0, 1) + + // Topic name + requestBody = append(requestBody, 0, byte(len(topicName))) + requestBody = append(requestBody, []byte(topicName)...) + + // Partitions count (1) + requestBody = append(requestBody, 0, 0, 0, 1) + + // Partition 0 with record set + requestBody = append(requestBody, 0, 0, 0, 0) // partition ID + + recordSet := make([]byte, 32) + binary.BigEndian.PutUint32(recordSet[16:20], 1) // record count + recordSetSize := uint32(len(recordSet)) + requestBody = append(requestBody, byte(recordSetSize>>24), byte(recordSetSize>>16), byte(recordSetSize>>8), byte(recordSetSize)) + requestBody = append(requestBody, recordSet...) 
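+	// With acks=0 the handler suppresses the response entirely (nothing is
+	// written back on the wire), but the records must still be applied to
+	// the ledger.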
+
+	response, err := h.handleProduce(correlationID, requestBody)
+	if err != nil {
+		t.Fatalf("handleProduce: %v", err)
+	}
+
+	// For acks=0, the handler should return an empty response
+	if len(response) != 0 {
+		t.Errorf("fire and forget response: got %d bytes, want 0", len(response))
+	}
+
+	// But the record should still be added to the ledger
+	ledger := h.GetLedger(topicName, 0)
+	if ledger == nil {
+		t.Fatalf("ledger not found for topic-partition")
+	}
+
+	if hwm := ledger.GetHighWaterMark(); hwm == 0 {
+		t.Errorf("high water mark: got %d, want > 0", hwm)
+	}
+}
+
+func TestHandler_parseRecordSet(t *testing.T) {
+	h := NewHandler()
+
+	// Test with a valid record set; bytes 16:20 hold the record count in the
+	// simplified layout parseRecordSet expects (a real v2 batch would carry
+	// the magic byte and CRC there instead)
+	recordSet := make([]byte, 32)
+	binary.BigEndian.PutUint64(recordSet[0:8], 0)   // base offset
+	binary.BigEndian.PutUint32(recordSet[8:12], 24) // batch length
+	binary.BigEndian.PutUint32(recordSet[12:16], 0) // partition leader epoch
+	binary.BigEndian.PutUint32(recordSet[16:20], 3) // record count
+
+	count, size, err := h.parseRecordSet(recordSet)
+	if err != nil {
+		t.Fatalf("parseRecordSet: %v", err)
+	}
+	if count != 3 {
+		t.Errorf("record count: got %d, want 3", count)
+	}
+	if size != int32(len(recordSet)) {
+		t.Errorf("total size: got %d, want %d", size, len(recordSet))
+	}
+
+	// Test with an invalid record set (too small)
+	invalidRecordSet := []byte{1, 2, 3}
+	_, _, err = h.parseRecordSet(invalidRecordSet)
+	if err == nil {
+		t.Errorf("expected error for invalid record set")
+	}
+
+	// Test with an unrealistic record count (should fall back to estimation)
+	badRecordSet := make([]byte, 32)
+	binary.BigEndian.PutUint32(badRecordSet[16:20], 999999999) // unrealistic count
+
+	count, size, err = h.parseRecordSet(badRecordSet)
+	if err != nil {
+		t.Fatalf("parseRecordSet fallback: %v", err)
+	}
+	if count <= 0 {
+		t.Errorf("fallback count: got %d, want > 0", count)
+	}
+	if size != int32(len(badRecordSet)) {
+		t.Errorf("fallback size: got %d, want %d", size, len(badRecordSet))
+	}
+
+	// Test with a small batch (should estimate 1 record)
+	smallRecordSet := make([]byte, 18) // just enough for the header check
+	count, _, err = h.parseRecordSet(smallRecordSet)
+	if err != nil {
+		t.Fatalf("parseRecordSet small batch: %v", err)
+	}
+	if count != 1 {
+		t.Errorf("small batch count: got %d, want 1", count)
+	}
+}
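+
+// buildProduceRequestBody is a sketch of a helper the three Produce tests
+// above could share; each currently assembles the same simplified request
+// body by hand. The name and signature are illustrative, not part of this
+// patch's API: one topic, one partition (ID 0), fixed 5000ms timeout.
+func buildProduceRequestBody(clientID, topicName string, acks uint16, recordSet []byte) []byte {
+	body := make([]byte, 0, 64+len(clientID)+len(topicName)+len(recordSet))
+	body = append(body, byte(len(clientID)>>8), byte(len(clientID))) // client_id size
+	body = append(body, []byte(clientID)...)
+	body = append(body, byte(acks>>8), byte(acks)) // acks
+	body = append(body, 0, 0, 0x13, 0x88)          // timeout: 5000ms
+	body = append(body, 0, 0, 0, 1)                // topics count (1)
+	body = append(body, byte(len(topicName)>>8), byte(len(topicName)))
+	body = append(body, []byte(topicName)...)
+	body = append(body, 0, 0, 0, 1) // partitions count (1)
+	body = append(body, 0, 0, 0, 0) // partition ID 0
+	n := uint32(len(recordSet))
+	body = append(body, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) // record set size
+	body = append(body, recordSet...)
+	return body
+}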