diff --git a/weed/mq/kafka/protocol/api_versions_test.go b/weed/mq/kafka/protocol/api_versions_test.go
index fb979088e..1dd7c1504 100644
--- a/weed/mq/kafka/protocol/api_versions_test.go
+++ b/weed/mq/kafka/protocol/api_versions_test.go
@@ -10,7 +10,7 @@ func TestApiVersions_AdvertisedVersionsMatch(t *testing.T) {
 	handler := NewTestHandler()
 	defer handler.Close()
 
-	response, err := handler.handleApiVersions(12345)
+	response, err := handler.handleApiVersions(12345, 0)
 	if err != nil {
 		t.Fatalf("handleApiVersions failed: %v", err)
 	}
@@ -253,7 +253,7 @@ func TestApiVersions_ResponseFormat(t *testing.T) {
 	handler := NewTestHandler()
 	defer handler.Close()
 
-	response, err := handler.handleApiVersions(99999)
+	response, err := handler.handleApiVersions(99999, 0)
 	if err != nil {
 		t.Fatalf("handleApiVersions failed: %v", err)
 	}
diff --git a/weed/mq/kafka/protocol/flexible_versions.go b/weed/mq/kafka/protocol/flexible_versions.go
new file mode 100644
index 000000000..a013eb5f8
--- /dev/null
+++ b/weed/mq/kafka/protocol/flexible_versions.go
@@ -0,0 +1,359 @@
+package protocol
+
+import (
+	"encoding/binary"
+	"fmt"
+)
+
+// FlexibleVersions provides utilities for handling Kafka flexible versions protocol
+// Flexible versions use compact arrays/strings and tagged fields for backward compatibility
+
+// CompactArrayLength encodes a length for compact arrays
+// Compact arrays encode length as length+1, where 0 means empty array
+func CompactArrayLength(length uint32) []byte {
+	if length == 0 {
+		return []byte{0} // Empty array
+	}
+	return EncodeUvarint(length + 1)
+}
+
+// DecodeCompactArrayLength decodes a compact array length
+// Returns the actual length and number of bytes consumed
+func DecodeCompactArrayLength(data []byte) (uint32, int, error) {
+	if len(data) == 0 {
+		return 0, 0, fmt.Errorf("no data for compact array length")
+	}
+
+	if data[0] == 0 {
+		return 0, 1, nil // Empty array
+	}
+
+	length, consumed, err := DecodeUvarint(data)
+	if err != nil {
+		return 0, 0, fmt.Errorf("decode compact array length: %w", err)
+	}
+
+	if length == 0 {
+		return 0, consumed, fmt.Errorf("invalid compact array length encoding")
+	}
+
+	return length - 1, consumed, nil
+}
+
+// CompactStringLength encodes a length for compact strings
+// Compact strings encode length as length+1, where 0 means null string
+func CompactStringLength(length int) []byte {
+	if length < 0 {
+		return []byte{0} // Null string
+	}
+	return EncodeUvarint(uint32(length + 1))
+}
+
+// DecodeCompactStringLength decodes a compact string length
+// Returns the actual length (-1 for null), and number of bytes consumed
+func DecodeCompactStringLength(data []byte) (int, int, error) {
+	if len(data) == 0 {
+		return 0, 0, fmt.Errorf("no data for compact string length")
+	}
+
+	if data[0] == 0 {
+		return -1, 1, nil // Null string
+	}
+
+	length, consumed, err := DecodeUvarint(data)
+	if err != nil {
+		return 0, 0, fmt.Errorf("decode compact string length: %w", err)
+	}
+
+	if length == 0 {
+		return 0, consumed, fmt.Errorf("invalid compact string length encoding")
+	}
+
+	return int(length - 1), consumed, nil
+}
+
+// EncodeUvarint encodes an unsigned integer using variable-length encoding
+// This is used for compact arrays, strings, and tagged fields
+func EncodeUvarint(value uint32) []byte {
+	var buf []byte
+	for value >= 0x80 {
+		buf = append(buf, byte(value)|0x80)
+		value >>= 7
+	}
+	buf = append(buf, byte(value))
+	return buf
+}
+
+// DecodeUvarint decodes a variable-length unsigned integer
+// Returns the decoded value and number of
bytes consumed +func DecodeUvarint(data []byte) (uint32, int, error) { + var value uint32 + var shift uint + var consumed int + + for i, b := range data { + consumed = i + 1 + value |= uint32(b&0x7F) << shift + + if (b & 0x80) == 0 { + return value, consumed, nil + } + + shift += 7 + if shift >= 32 { + return 0, consumed, fmt.Errorf("uvarint overflow") + } + } + + return 0, consumed, fmt.Errorf("incomplete uvarint") +} + +// TaggedField represents a tagged field in flexible versions +type TaggedField struct { + Tag uint32 + Data []byte +} + +// TaggedFields represents a collection of tagged fields +type TaggedFields struct { + Fields []TaggedField +} + +// EncodeTaggedFields encodes tagged fields for flexible versions +func (tf *TaggedFields) Encode() []byte { + if len(tf.Fields) == 0 { + return []byte{0} // Empty tagged fields + } + + var buf []byte + + // Number of tagged fields + buf = append(buf, EncodeUvarint(uint32(len(tf.Fields)))...) + + for _, field := range tf.Fields { + // Tag + buf = append(buf, EncodeUvarint(field.Tag)...) + // Size + buf = append(buf, EncodeUvarint(uint32(len(field.Data)))...) + // Data + buf = append(buf, field.Data...) + } + + return buf +} + +// DecodeTaggedFields decodes tagged fields from flexible versions +func DecodeTaggedFields(data []byte) (*TaggedFields, int, error) { + if len(data) == 0 { + return &TaggedFields{}, 0, fmt.Errorf("no data for tagged fields") + } + + if data[0] == 0 { + return &TaggedFields{}, 1, nil // Empty tagged fields + } + + offset := 0 + + // Number of tagged fields + numFields, consumed, err := DecodeUvarint(data[offset:]) + if err != nil { + return nil, 0, fmt.Errorf("decode tagged fields count: %w", err) + } + offset += consumed + + fields := make([]TaggedField, numFields) + + for i := uint32(0); i < numFields; i++ { + // Tag + tag, consumed, err := DecodeUvarint(data[offset:]) + if err != nil { + return nil, 0, fmt.Errorf("decode tagged field %d tag: %w", i, err) + } + offset += consumed + + // Size + size, consumed, err := DecodeUvarint(data[offset:]) + if err != nil { + return nil, 0, fmt.Errorf("decode tagged field %d size: %w", i, err) + } + offset += consumed + + // Data + if offset+int(size) > len(data) { + return nil, 0, fmt.Errorf("tagged field %d data truncated", i) + } + + fields[i] = TaggedField{ + Tag: tag, + Data: data[offset : offset+int(size)], + } + offset += int(size) + } + + return &TaggedFields{Fields: fields}, offset, nil +} + +// IsFlexibleVersion determines if an API version uses flexible versions +// This is API-specific and based on when each API adopted flexible versions +func IsFlexibleVersion(apiKey, apiVersion uint16) bool { + switch apiKey { + case 18: // ApiVersions + return apiVersion >= 3 + case 3: // Metadata + return apiVersion >= 9 + case 1: // Fetch + return apiVersion >= 12 + case 0: // Produce + return apiVersion >= 9 + case 11: // JoinGroup + return apiVersion >= 6 + case 14: // SyncGroup + return apiVersion >= 4 + case 8: // OffsetCommit + return apiVersion >= 8 + case 9: // OffsetFetch + return apiVersion >= 6 + case 10: // FindCoordinator + return apiVersion >= 3 + case 12: // Heartbeat + return apiVersion >= 4 + case 13: // LeaveGroup + return apiVersion >= 4 + case 19: // CreateTopics + return apiVersion >= 2 + case 20: // DeleteTopics + return apiVersion >= 4 + default: + return false + } +} + +// FlexibleString encodes a string for flexible versions (compact format) +func FlexibleString(s string) []byte { + if s == "" { + return []byte{0} // Null string + } + + var buf []byte + 
buf = append(buf, CompactStringLength(len(s))...) + buf = append(buf, []byte(s)...) + return buf +} + +// FlexibleNullableString encodes a nullable string for flexible versions +func FlexibleNullableString(s *string) []byte { + if s == nil { + return []byte{0} // Null string + } + return FlexibleString(*s) +} + +// DecodeFlexibleString decodes a flexible string +// Returns the string (empty for null) and bytes consumed +func DecodeFlexibleString(data []byte) (string, int, error) { + length, consumed, err := DecodeCompactStringLength(data) + if err != nil { + return "", 0, err + } + + if length < 0 { + return "", consumed, nil // Null string -> empty string + } + + if consumed+length > len(data) { + return "", 0, fmt.Errorf("string data truncated") + } + + return string(data[consumed : consumed+length]), consumed + length, nil +} + +// FlexibleVersionHeader handles the request header parsing for flexible versions +type FlexibleVersionHeader struct { + APIKey uint16 + APIVersion uint16 + CorrelationID uint32 + ClientID *string + TaggedFields *TaggedFields +} + +// ParseRequestHeader parses a Kafka request header, handling both regular and flexible versions +func ParseRequestHeader(data []byte) (*FlexibleVersionHeader, []byte, error) { + if len(data) < 8 { + return nil, nil, fmt.Errorf("header too short") + } + + header := &FlexibleVersionHeader{} + offset := 0 + + // API Key (2 bytes) + header.APIKey = binary.BigEndian.Uint16(data[offset : offset+2]) + offset += 2 + + // API Version (2 bytes) + header.APIVersion = binary.BigEndian.Uint16(data[offset : offset+2]) + offset += 2 + + // Correlation ID (4 bytes) + header.CorrelationID = binary.BigEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Client ID handling depends on flexible version + isFlexible := IsFlexibleVersion(header.APIKey, header.APIVersion) + + if isFlexible { + // Flexible versions use compact strings + clientID, consumed, err := DecodeFlexibleString(data[offset:]) + if err != nil { + return nil, nil, fmt.Errorf("decode flexible client_id: %w", err) + } + offset += consumed + + if clientID != "" { + header.ClientID = &clientID + } + + // Parse tagged fields in header + taggedFields, consumed, err := DecodeTaggedFields(data[offset:]) + if err != nil { + return nil, nil, fmt.Errorf("decode header tagged fields: %w", err) + } + offset += consumed + header.TaggedFields = taggedFields + + } else { + // Regular versions use standard strings + if len(data) < offset+2 { + return nil, nil, fmt.Errorf("missing client_id length") + } + + clientIDLen := int16(binary.BigEndian.Uint16(data[offset : offset+2])) + offset += 2 + + if clientIDLen >= 0 { + if len(data) < offset+int(clientIDLen) { + return nil, nil, fmt.Errorf("client_id truncated") + } + + clientID := string(data[offset : offset+int(clientIDLen)]) + header.ClientID = &clientID + offset += int(clientIDLen) + } + // No tagged fields in regular versions + } + + return header, data[offset:], nil +} + +// EncodeFlexibleResponse encodes a response with proper flexible version formatting +func EncodeFlexibleResponse(correlationID uint32, data []byte, hasTaggedFields bool) []byte { + response := make([]byte, 4) + binary.BigEndian.PutUint32(response, correlationID) + response = append(response, data...) 
+ + if hasTaggedFields { + // Add empty tagged fields for flexible responses + response = append(response, 0) + } + + return response +} diff --git a/weed/mq/kafka/protocol/flexible_versions_integration_test.go b/weed/mq/kafka/protocol/flexible_versions_integration_test.go new file mode 100644 index 000000000..5fbf7d19e --- /dev/null +++ b/weed/mq/kafka/protocol/flexible_versions_integration_test.go @@ -0,0 +1,305 @@ +package protocol + +import ( + "encoding/binary" + "testing" +) + +func TestApiVersions_FlexibleVersionSupport(t *testing.T) { + handler := NewTestHandler() + + testCases := []struct { + name string + apiVersion uint16 + expectFlexible bool + }{ + {"ApiVersions v0", 0, false}, + {"ApiVersions v1", 1, false}, + {"ApiVersions v2", 2, false}, + {"ApiVersions v3", 3, true}, + {"ApiVersions v4", 4, true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + correlationID := uint32(12345) + + response, err := handler.handleApiVersions(correlationID, tc.apiVersion) + if err != nil { + t.Fatalf("handleApiVersions failed: %v", err) + } + + if len(response) < 4 { + t.Fatalf("Response too short: %d bytes", len(response)) + } + + // Check correlation ID + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID) + } + + // Check error code + errorCode := binary.BigEndian.Uint16(response[4:6]) + if errorCode != 0 { + t.Errorf("Error code = %d, want 0", errorCode) + } + + // Parse API keys count based on version + offset := 6 + var apiKeysCount uint32 + + if tc.expectFlexible { + // Should use compact array format + count, consumed, err := DecodeCompactArrayLength(response[offset:]) + if err != nil { + t.Fatalf("Failed to decode compact array length: %v", err) + } + apiKeysCount = count + offset += consumed + } else { + // Should use regular array format + if len(response) < offset+4 { + t.Fatalf("Response too short for regular array length") + } + apiKeysCount = binary.BigEndian.Uint32(response[offset:offset+4]) + offset += 4 + } + + if apiKeysCount != 14 { + t.Errorf("API keys count = %d, want 14", apiKeysCount) + } + + // Verify that we have enough data for all API keys + // Each API key entry is 6 bytes: api_key(2) + min_version(2) + max_version(2) + expectedMinSize := offset + int(apiKeysCount*6) + if tc.expectFlexible { + expectedMinSize += 1 // tagged fields + } + + if len(response) < expectedMinSize { + t.Errorf("Response too short: got %d bytes, expected at least %d", len(response), expectedMinSize) + } + + // Check that ApiVersions API itself is properly listed + // API Key 18 should be the first entry + if len(response) >= offset+6 { + apiKey := binary.BigEndian.Uint16(response[offset:offset+2]) + minVersion := binary.BigEndian.Uint16(response[offset+2:offset+4]) + maxVersion := binary.BigEndian.Uint16(response[offset+4:offset+6]) + + if apiKey != 18 { + t.Errorf("First API key = %d, want 18 (ApiVersions)", apiKey) + } + if minVersion != 0 { + t.Errorf("ApiVersions min version = %d, want 0", minVersion) + } + if maxVersion != 3 { + t.Errorf("ApiVersions max version = %d, want 3", maxVersion) + } + } + + t.Logf("ApiVersions v%d response: %d bytes, flexible: %v", tc.apiVersion, len(response), tc.expectFlexible) + }) + } +} + +func TestFlexibleVersions_HeaderParsingIntegration(t *testing.T) { + testCases := []struct { + name string + apiKey uint16 + apiVersion uint16 + clientID string + expectFlexible bool + }{ + {"Metadata v8 
(regular)", 3, 8, "test-client", false}, + {"Metadata v9 (flexible)", 3, 9, "test-client", true}, + {"ApiVersions v2 (regular)", 18, 2, "test-client", false}, + {"ApiVersions v3 (flexible)", 18, 3, "test-client", true}, + {"CreateTopics v1 (regular)", 19, 1, "test-client", false}, + {"CreateTopics v2 (flexible)", 19, 2, "test-client", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Construct request header + var headerData []byte + + // API Key (2 bytes) + headerData = append(headerData, byte(tc.apiKey>>8), byte(tc.apiKey)) + + // API Version (2 bytes) + headerData = append(headerData, byte(tc.apiVersion>>8), byte(tc.apiVersion)) + + // Correlation ID (4 bytes) + correlationID := uint32(54321) + corrBytes := make([]byte, 4) + binary.BigEndian.PutUint32(corrBytes, correlationID) + headerData = append(headerData, corrBytes...) + + if tc.expectFlexible { + // Flexible version: compact string for client ID + headerData = append(headerData, FlexibleString(tc.clientID)...) + // Empty tagged fields + headerData = append(headerData, 0) + } else { + // Regular version: standard string for client ID + clientIDBytes := []byte(tc.clientID) + headerData = append(headerData, byte(len(clientIDBytes)>>8), byte(len(clientIDBytes))) + headerData = append(headerData, clientIDBytes...) + } + + // Add dummy request body + headerData = append(headerData, 1, 2, 3, 4) + + // Parse header + header, body, err := ParseRequestHeader(headerData) + if err != nil { + t.Fatalf("ParseRequestHeader failed: %v", err) + } + + // Validate parsed header + if header.APIKey != tc.apiKey { + t.Errorf("APIKey = %d, want %d", header.APIKey, tc.apiKey) + } + if header.APIVersion != tc.apiVersion { + t.Errorf("APIVersion = %d, want %d", header.APIVersion, tc.apiVersion) + } + if header.CorrelationID != correlationID { + t.Errorf("CorrelationID = %d, want %d", header.CorrelationID, correlationID) + } + if header.ClientID == nil || *header.ClientID != tc.clientID { + t.Errorf("ClientID = %v, want %s", header.ClientID, tc.clientID) + } + + // Check tagged fields presence + hasTaggedFields := header.TaggedFields != nil + if hasTaggedFields != tc.expectFlexible { + t.Errorf("Tagged fields present = %v, want %v", hasTaggedFields, tc.expectFlexible) + } + + // Validate body + expectedBody := []byte{1, 2, 3, 4} + if len(body) != len(expectedBody) { + t.Errorf("Body length = %d, want %d", len(body), len(expectedBody)) + } + for i, b := range expectedBody { + if i < len(body) && body[i] != b { + t.Errorf("Body[%d] = %d, want %d", i, body[i], b) + } + } + + t.Logf("Header parsing for %s v%d: flexible=%v, client=%s", + getAPIName(tc.apiKey), tc.apiVersion, tc.expectFlexible, tc.clientID) + }) + } +} + +func TestCreateTopics_FlexibleVersionConsistency(t *testing.T) { + handler := NewTestHandler() + + // Test that CreateTopics v2+ continues to work correctly with flexible version utilities + correlationID := uint32(99999) + + // Build CreateTopics v2 request using flexible version utilities + var requestData []byte + + // Topics array (compact: 1 topic = 2) + requestData = append(requestData, 2) + + // Topic name (compact string) + topicName := "flexible-test-topic" + requestData = append(requestData, FlexibleString(topicName)...) 
+ + // Number of partitions (4 bytes) + requestData = append(requestData, 0, 0, 0, 3) + + // Replication factor (2 bytes) + requestData = append(requestData, 0, 1) + + // Configs array (compact: empty = 0) + requestData = append(requestData, 0) + + // Tagged fields (empty) + requestData = append(requestData, 0) + + // Timeout (4 bytes) + requestData = append(requestData, 0, 0, 0x27, 0x10) // 10000ms + + // Validate only (1 byte) + requestData = append(requestData, 0) + + // Tagged fields at end + requestData = append(requestData, 0) + + // Call CreateTopics v2 + response, err := handler.handleCreateTopicsV2Plus(correlationID, 2, requestData) + if err != nil { + t.Fatalf("handleCreateTopicsV2Plus failed: %v", err) + } + + if len(response) < 8 { + t.Fatalf("Response too short: %d bytes", len(response)) + } + + // Check correlation ID + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID) + } + + // Verify topic was created + if !handler.seaweedMQHandler.TopicExists(topicName) { + t.Errorf("Topic '%s' was not created", topicName) + } + + t.Logf("CreateTopics v2 with flexible utilities: topic created successfully") +} + +func BenchmarkFlexibleVersions_HeaderParsing(b *testing.B) { + // Pre-construct headers for different scenarios + scenarios := []struct { + name string + data []byte + }{ + { + name: "Regular_v1", + data: func() []byte { + var data []byte + data = append(data, 0, 3, 0, 1) // Metadata v1 + corrBytes := make([]byte, 4) + binary.BigEndian.PutUint32(corrBytes, 12345) + data = append(data, corrBytes...) + data = append(data, 0, 11, 'b', 'e', 'n', 'c', 'h', '-', 'c', 'l', 'i', 'e', 'n', 't') + data = append(data, 1, 2, 3) + return data + }(), + }, + { + name: "Flexible_v9", + data: func() []byte { + var data []byte + data = append(data, 0, 3, 0, 9) // Metadata v9 + corrBytes := make([]byte, 4) + binary.BigEndian.PutUint32(corrBytes, 12345) + data = append(data, corrBytes...) + data = append(data, FlexibleString("bench-client")...) 
+ data = append(data, 0) // empty tagged fields + data = append(data, 1, 2, 3) + return data + }(), + }, + } + + for _, scenario := range scenarios { + b.Run(scenario.name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := ParseRequestHeader(scenario.data) + if err != nil { + b.Fatalf("ParseRequestHeader failed: %v", err) + } + } + }) + } +} diff --git a/weed/mq/kafka/protocol/flexible_versions_test.go b/weed/mq/kafka/protocol/flexible_versions_test.go new file mode 100644 index 000000000..c14487c9e --- /dev/null +++ b/weed/mq/kafka/protocol/flexible_versions_test.go @@ -0,0 +1,486 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" +) + +func TestEncodeDecodeUvarint(t *testing.T) { + testCases := []uint32{ + 0, 1, 127, 128, 255, 256, 16383, 16384, 32767, 32768, 65535, 65536, + 0x1FFFFF, 0x200000, 0x0FFFFFFF, 0x10000000, 0xFFFFFFFF, + } + + for _, value := range testCases { + t.Run(fmt.Sprintf("value_%d", value), func(t *testing.T) { + encoded := EncodeUvarint(value) + decoded, consumed, err := DecodeUvarint(encoded) + + if err != nil { + t.Fatalf("DecodeUvarint failed: %v", err) + } + + if decoded != value { + t.Errorf("Decoded value %d != original %d", decoded, value) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + } +} + +func TestCompactArrayLength(t *testing.T) { + testCases := []struct { + name string + length uint32 + expected []byte + }{ + {"Empty array", 0, []byte{0}}, + {"Single element", 1, []byte{2}}, + {"Small array", 10, []byte{11}}, + {"Large array", 127, []byte{128, 1}}, // 128 = 127+1 encoded as varint (two bytes since >= 128) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encoded := CompactArrayLength(tc.length) + if !bytes.Equal(encoded, tc.expected) { + t.Errorf("CompactArrayLength(%d) = %v, want %v", tc.length, encoded, tc.expected) + } + + // Test round trip + decoded, consumed, err := DecodeCompactArrayLength(encoded) + if err != nil { + t.Fatalf("DecodeCompactArrayLength failed: %v", err) + } + + if decoded != tc.length { + t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + } +} + +func TestCompactStringLength(t *testing.T) { + testCases := []struct { + name string + length int + expected []byte + }{ + {"Null string", -1, []byte{0}}, + {"Empty string", 0, []byte{1}}, + {"Short string", 5, []byte{6}}, + {"Medium string", 100, []byte{101}}, // 101 encoded as varint (single byte since < 128) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encoded := CompactStringLength(tc.length) + if !bytes.Equal(encoded, tc.expected) { + t.Errorf("CompactStringLength(%d) = %v, want %v", tc.length, encoded, tc.expected) + } + + // Test round trip + decoded, consumed, err := DecodeCompactStringLength(encoded) + if err != nil { + t.Fatalf("DecodeCompactStringLength failed: %v", err) + } + + if decoded != tc.length { + t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + } +} + +func TestFlexibleString(t *testing.T) { + testCases := []struct { + name string + input string + expected []byte + }{ + {"Empty string", "", []byte{0}}, + {"Hello", "hello", []byte{6, 'h', 'e', 'l', 'l', 'o'}}, + {"Unicode", "测试", 
[]byte{7, 0xE6, 0xB5, 0x8B, 0xE8, 0xAF, 0x95}}, // UTF-8 encoded + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encoded := FlexibleString(tc.input) + if !bytes.Equal(encoded, tc.expected) { + t.Errorf("FlexibleString(%q) = %v, want %v", tc.input, encoded, tc.expected) + } + + // Test round trip + decoded, consumed, err := DecodeFlexibleString(encoded) + if err != nil { + t.Fatalf("DecodeFlexibleString failed: %v", err) + } + + if decoded != tc.input { + t.Errorf("Round trip failed: got %q, want %q", decoded, tc.input) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + } +} + +func TestFlexibleNullableString(t *testing.T) { + // Null string + nullResult := FlexibleNullableString(nil) + expected := []byte{0} + if !bytes.Equal(nullResult, expected) { + t.Errorf("FlexibleNullableString(nil) = %v, want %v", nullResult, expected) + } + + // Non-null string + str := "test" + nonNullResult := FlexibleNullableString(&str) + expectedNonNull := []byte{5, 't', 'e', 's', 't'} + if !bytes.Equal(nonNullResult, expectedNonNull) { + t.Errorf("FlexibleNullableString(&%q) = %v, want %v", str, nonNullResult, expectedNonNull) + } +} + +func TestTaggedFields(t *testing.T) { + t.Run("Empty tagged fields", func(t *testing.T) { + tf := &TaggedFields{} + encoded := tf.Encode() + expected := []byte{0} + + if !bytes.Equal(encoded, expected) { + t.Errorf("Empty TaggedFields.Encode() = %v, want %v", encoded, expected) + } + + // Test round trip + decoded, consumed, err := DecodeTaggedFields(encoded) + if err != nil { + t.Fatalf("DecodeTaggedFields failed: %v", err) + } + + if len(decoded.Fields) != 0 { + t.Errorf("Decoded tagged fields length = %d, want 0", len(decoded.Fields)) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + + t.Run("Single tagged field", func(t *testing.T) { + tf := &TaggedFields{ + Fields: []TaggedField{ + {Tag: 1, Data: []byte("test")}, + }, + } + + encoded := tf.Encode() + + // Test round trip + decoded, consumed, err := DecodeTaggedFields(encoded) + if err != nil { + t.Fatalf("DecodeTaggedFields failed: %v", err) + } + + if len(decoded.Fields) != 1 { + t.Fatalf("Decoded tagged fields length = %d, want 1", len(decoded.Fields)) + } + + field := decoded.Fields[0] + if field.Tag != 1 { + t.Errorf("Decoded tag = %d, want 1", field.Tag) + } + + if !bytes.Equal(field.Data, []byte("test")) { + t.Errorf("Decoded data = %v, want %v", field.Data, []byte("test")) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) + + t.Run("Multiple tagged fields", func(t *testing.T) { + tf := &TaggedFields{ + Fields: []TaggedField{ + {Tag: 1, Data: []byte("first")}, + {Tag: 5, Data: []byte("second")}, + }, + } + + encoded := tf.Encode() + + // Test round trip + decoded, consumed, err := DecodeTaggedFields(encoded) + if err != nil { + t.Fatalf("DecodeTaggedFields failed: %v", err) + } + + if len(decoded.Fields) != 2 { + t.Fatalf("Decoded tagged fields length = %d, want 2", len(decoded.Fields)) + } + + // Check first field + field1 := decoded.Fields[0] + if field1.Tag != 1 { + t.Errorf("Decoded field 1 tag = %d, want 1", field1.Tag) + } + if !bytes.Equal(field1.Data, []byte("first")) { + t.Errorf("Decoded field 1 data = %v, want %v", field1.Data, []byte("first")) + } + + // Check second field + field2 := decoded.Fields[1] + if field2.Tag != 5 { + t.Errorf("Decoded field 2 
tag = %d, want 5", field2.Tag) + } + if !bytes.Equal(field2.Data, []byte("second")) { + t.Errorf("Decoded field 2 data = %v, want %v", field2.Data, []byte("second")) + } + + if consumed != len(encoded) { + t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) + } + }) +} + +func TestIsFlexibleVersion(t *testing.T) { + testCases := []struct { + apiKey uint16 + apiVersion uint16 + expected bool + name string + }{ + // ApiVersions + {18, 2, false, "ApiVersions v2"}, + {18, 3, true, "ApiVersions v3"}, + {18, 4, true, "ApiVersions v4"}, + + // Metadata + {3, 8, false, "Metadata v8"}, + {3, 9, true, "Metadata v9"}, + {3, 10, true, "Metadata v10"}, + + // Fetch + {1, 11, false, "Fetch v11"}, + {1, 12, true, "Fetch v12"}, + {1, 13, true, "Fetch v13"}, + + // Produce + {0, 8, false, "Produce v8"}, + {0, 9, true, "Produce v9"}, + {0, 10, true, "Produce v10"}, + + // CreateTopics + {19, 1, false, "CreateTopics v1"}, + {19, 2, true, "CreateTopics v2"}, + {19, 3, true, "CreateTopics v3"}, + + // Unknown API + {99, 1, false, "Unknown API"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsFlexibleVersion(tc.apiKey, tc.apiVersion) + if result != tc.expected { + t.Errorf("IsFlexibleVersion(%d, %d) = %v, want %v", + tc.apiKey, tc.apiVersion, result, tc.expected) + } + }) + } +} + +func TestParseRequestHeader(t *testing.T) { + t.Run("Regular version header", func(t *testing.T) { + // Construct a regular version header (Metadata v1) + data := make([]byte, 0) + data = append(data, 0, 3) // API Key = 3 (Metadata) + data = append(data, 0, 1) // API Version = 1 + data = append(data, 0, 0, 0, 123) // Correlation ID = 123 + data = append(data, 0, 4) // Client ID length = 4 + data = append(data, 't', 'e', 's', 't') // Client ID = "test" + data = append(data, 1, 2, 3) // Request body + + header, body, err := ParseRequestHeader(data) + if err != nil { + t.Fatalf("ParseRequestHeader failed: %v", err) + } + + if header.APIKey != 3 { + t.Errorf("APIKey = %d, want 3", header.APIKey) + } + if header.APIVersion != 1 { + t.Errorf("APIVersion = %d, want 1", header.APIVersion) + } + if header.CorrelationID != 123 { + t.Errorf("CorrelationID = %d, want 123", header.CorrelationID) + } + if header.ClientID == nil || *header.ClientID != "test" { + t.Errorf("ClientID = %v, want 'test'", header.ClientID) + } + if header.TaggedFields != nil { + t.Errorf("TaggedFields should be nil for regular versions") + } + + expectedBody := []byte{1, 2, 3} + if !bytes.Equal(body, expectedBody) { + t.Errorf("Body = %v, want %v", body, expectedBody) + } + }) + + t.Run("Flexible version header", func(t *testing.T) { + // Construct a flexible version header (ApiVersions v3) + data := make([]byte, 0) + data = append(data, 0, 18) // API Key = 18 (ApiVersions) + data = append(data, 0, 3) // API Version = 3 (flexible) + + // Correlation ID = 456 (4 bytes, big endian) + correlationID := make([]byte, 4) + binary.BigEndian.PutUint32(correlationID, 456) + data = append(data, correlationID...) 
+ + data = append(data, 5, 't', 'e', 's', 't') // Client ID = "test" (compact string) + data = append(data, 0) // Empty tagged fields + data = append(data, 4, 5, 6) // Request body + + header, body, err := ParseRequestHeader(data) + if err != nil { + t.Fatalf("ParseRequestHeader failed: %v", err) + } + + if header.APIKey != 18 { + t.Errorf("APIKey = %d, want 18", header.APIKey) + } + if header.APIVersion != 3 { + t.Errorf("APIVersion = %d, want 3", header.APIVersion) + } + if header.CorrelationID != 456 { + t.Errorf("CorrelationID = %d, want 456", header.CorrelationID) + } + if header.ClientID == nil || *header.ClientID != "test" { + t.Errorf("ClientID = %v, want 'test'", header.ClientID) + } + if header.TaggedFields == nil { + t.Errorf("TaggedFields should not be nil for flexible versions") + } + if len(header.TaggedFields.Fields) != 0 { + t.Errorf("TaggedFields should be empty") + } + + expectedBody := []byte{4, 5, 6} + if !bytes.Equal(body, expectedBody) { + t.Errorf("Body = %v, want %v", body, expectedBody) + } + }) + + t.Run("Null client ID", func(t *testing.T) { + // Regular version with null client ID + data := make([]byte, 0) + data = append(data, 0, 3) // API Key = 3 (Metadata) + data = append(data, 0, 1) // API Version = 1 + + // Correlation ID = 789 (4 bytes, big endian) + correlationID := make([]byte, 4) + binary.BigEndian.PutUint32(correlationID, 789) + data = append(data, correlationID...) + + data = append(data, 0xFF, 0xFF) // Client ID length = -1 (null) + data = append(data, 7, 8, 9) // Request body + + header, body, err := ParseRequestHeader(data) + if err != nil { + t.Fatalf("ParseRequestHeader failed: %v", err) + } + + if header.ClientID != nil { + t.Errorf("ClientID = %v, want nil", header.ClientID) + } + + expectedBody := []byte{7, 8, 9} + if !bytes.Equal(body, expectedBody) { + t.Errorf("Body = %v, want %v", body, expectedBody) + } + }) +} + +func TestEncodeFlexibleResponse(t *testing.T) { + correlationID := uint32(123) + data := []byte{1, 2, 3, 4} + + t.Run("Without tagged fields", func(t *testing.T) { + result := EncodeFlexibleResponse(correlationID, data, false) + expected := []byte{0, 0, 0, 123, 1, 2, 3, 4} + + if !bytes.Equal(result, expected) { + t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected) + } + }) + + t.Run("With tagged fields", func(t *testing.T) { + result := EncodeFlexibleResponse(correlationID, data, true) + expected := []byte{0, 0, 0, 123, 1, 2, 3, 4, 0} // 0 at end for empty tagged fields + + if !bytes.Equal(result, expected) { + t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected) + } + }) +} + +func BenchmarkEncodeUvarint(b *testing.B) { + testValues := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, val := range testValues { + EncodeUvarint(val) + } + } +} + +func BenchmarkDecodeUvarint(b *testing.B) { + // Pre-encode test values + testData := [][]byte{ + EncodeUvarint(0), + EncodeUvarint(127), + EncodeUvarint(128), + EncodeUvarint(16383), + EncodeUvarint(16384), + EncodeUvarint(65535), + EncodeUvarint(65536), + EncodeUvarint(0xFFFFFFFF), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, data := range testData { + DecodeUvarint(data) + } + } +} + +func BenchmarkFlexibleString(b *testing.B) { + testStrings := []string{"", "a", "hello", "this is a longer test string", "测试中文字符串"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, s := range testStrings { + FlexibleString(s) + } + } +} diff --git a/weed/mq/kafka/protocol/handler.go 
b/weed/mq/kafka/protocol/handler.go
index 3d182e61b..a9e048086 100644
--- a/weed/mq/kafka/protocol/handler.go
+++ b/weed/mq/kafka/protocol/handler.go
@@ -300,24 +300,49 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
 			continue
 		}
 
-		// Strip client_id (nullable STRING) from header to get pure request body
-		bodyOffset := 8
-		if len(messageBuf) < bodyOffset+2 {
-			return fmt.Errorf("invalid header: missing client_id length")
-		}
-		clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
-		bodyOffset += 2
-		if clientIDLen >= 0 {
-			if len(messageBuf) < bodyOffset+int(clientIDLen) {
-				return fmt.Errorf("invalid header: client_id truncated")
+		// Parse header using flexible version utilities for validation and client ID extraction
+		header, requestBody, parseErr := ParseRequestHeader(messageBuf)
+		if parseErr != nil {
+			// Fall back to basic header parsing if flexible version parsing fails
+			fmt.Printf("DEBUG: Flexible header parsing failed, using basic parsing: %v\n", parseErr)
+
+			// Basic header parsing fallback (original logic)
+			bodyOffset := 8
+			if len(messageBuf) < bodyOffset+2 {
+				return fmt.Errorf("invalid header: missing client_id length")
+			}
+			clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
+			bodyOffset += 2
+			if clientIDLen >= 0 {
+				if len(messageBuf) < bodyOffset+int(clientIDLen) {
+					return fmt.Errorf("invalid header: client_id truncated")
+				}
+				bodyOffset += int(clientIDLen)
 			}
-			// clientID := string(messageBuf[bodyOffset : bodyOffset+int(clientIDLen)])
-			bodyOffset += int(clientIDLen)
+			requestBody = messageBuf[bodyOffset:]
 		} else {
-			// client_id is null; nothing to skip
+			// Validate parsed header matches what we already extracted
+			if header.APIKey != apiKey || header.APIVersion != apiVersion || header.CorrelationID != correlationID {
+				fmt.Printf("DEBUG: Header parsing mismatch - using basic parsing as fallback\n")
+				// Fall back to basic parsing rather than failing
+				bodyOffset := 8
+				if len(messageBuf) < bodyOffset+2 {
+					return fmt.Errorf("invalid header: missing client_id length")
+				}
+				clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
+				bodyOffset += 2
+				if clientIDLen >= 0 {
+					if len(messageBuf) < bodyOffset+int(clientIDLen) {
+						return fmt.Errorf("invalid header: client_id truncated")
+					}
+					bodyOffset += int(clientIDLen)
+				}
+				requestBody = messageBuf[bodyOffset:]
+			} else if header.ClientID != nil {
+				// Log client ID if available and parsing was successful
+				fmt.Printf("DEBUG: Client ID: %s\n", *header.ClientID)
+			}
 		}
-		// TODO: Flexible versions have tagged fields in header; ignored for now
-		requestBody := messageBuf[bodyOffset:]
 
 		// Handle the request based on API key and version
 		var response []byte
@@ -325,7 +350,7 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
 
 		switch apiKey {
 		case 18: // ApiVersions
-			response, err = h.handleApiVersions(correlationID)
+			response, err = h.handleApiVersions(correlationID, apiVersion)
 		case 3: // Metadata
 			response, err = h.handleMetadata(correlationID, apiVersion, requestBody)
 		case 2: // ListOffsets
@@ -410,11 +435,11 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
 	}
 }
 
-func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
-	// Build ApiVersions response manually
-	// Response format (v0): correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys
-
-	response := make([]byte, 0, 64)
+func (h *Handler) handleApiVersions(correlationID uint32, apiVersion uint16) ([]byte, error) {
+	// Build ApiVersions response supporting flexible versions (v3+)
+	isFlexible := IsFlexibleVersion(18, apiVersion)
+
+	response := make([]byte, 0, 128)
 
 	// Correlation ID
 	correlationIDBytes := make([]byte, 4)
@@ -424,8 +449,15 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
 	// Error code (0 = no error)
 	response = append(response, 0, 0)
 
-	// Number of API keys (compact array format in newer versions, but using basic format for simplicity)
-	response = append(response, 0, 0, 0, 14) // 14 API keys
+	// Number of API keys - use compact or regular array format based on version
+	apiKeysCount := uint32(14)
+	if isFlexible {
+		// Compact array format for flexible versions
+		response = append(response, CompactArrayLength(apiKeysCount)...)
+	} else {
+		// Regular array format for older versions
+		response = append(response, 0, 0, 0, 14) // 14 API keys
+	}
 
 	// API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2)
 	response = append(response, 0, 18) // API key 18
@@ -500,7 +532,13 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
 	response = append(response, 0, 0) // min version 0
 	response = append(response, 0, 4) // max version 4
 
-	fmt.Printf("DEBUG: ApiVersions v0 response: %d bytes\n", len(response))
+	// Add tagged fields for flexible versions
+	if isFlexible {
+		// Empty tagged fields for now
+		response = append(response, 0)
+	}
+
+	fmt.Printf("DEBUG: ApiVersions v%d response: %d bytes\n", apiVersion, len(response))
 
 	return response, nil
 }
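
Reviewer note: a minimal usage sketch of the new helpers follows, assuming it lives in the same protocol package. The function name buildFlexibleHeaderSketch, the "example-client" ID, and the stand-in body bytes are illustrative only and are not part of this patch.

package protocol

import (
	"encoding/binary"
	"fmt"
)

// buildFlexibleHeaderSketch assembles an ApiVersions v3 request header the way a
// client would, then round-trips it through ParseRequestHeader. It exercises the
// compact string, tagged field, and version-gate helpers added in this change.
func buildFlexibleHeaderSketch() {
	const (
		apiKey        = uint16(18) // ApiVersions
		apiVersion    = uint16(3)  // first flexible version for this API
		correlationID = uint32(42)
	)

	var req []byte
	req = append(req, byte(apiKey>>8), byte(apiKey))
	req = append(req, byte(apiVersion>>8), byte(apiVersion))
	corr := make([]byte, 4)
	binary.BigEndian.PutUint32(corr, correlationID)
	req = append(req, corr...)

	clientID := "example-client"
	if IsFlexibleVersion(apiKey, apiVersion) {
		// Flexible headers carry a compact client_id (varint of length+1) plus tagged fields.
		req = append(req, FlexibleString(clientID)...)
		req = append(req, (&TaggedFields{}).Encode()...) // empty tagged fields -> single 0 byte
	} else {
		// Pre-flexible headers use a 2-byte length-prefixed client_id and no tagged fields.
		req = append(req, byte(len(clientID)>>8), byte(len(clientID)))
		req = append(req, clientID...)
	}
	req = append(req, 0xDE, 0xAD) // stand-in request body

	header, body, err := ParseRequestHeader(req)
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Printf("api=%d v=%d corr=%d client=%q body=%d bytes\n",
		header.APIKey, header.APIVersion, header.CorrelationID, *header.ClientID, len(body))
	// Expected output: api=18 v=3 corr=42 client="example-client" body=2 bytes
}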