Browse Source

Phase 6: Add basic flexible versions support

- Added flexible_versions.go with utilities for Kafka flexible versions (v3+)
- Implemented ParseRequestHeader for compact string parsing and tagged fields
- Added fallback mechanism in handler.go for backward compatibility
- Updated handleApiVersions to support flexible version responses
- Added comprehensive tests for flexible version utilities
- All protocol tests passing with robust error handling
pull/7231/head
chrislu 2 months ago
parent
commit
2e2ccbf488
  1. 4
      weed/mq/kafka/protocol/api_versions_test.go
  2. 359
      weed/mq/kafka/protocol/flexible_versions.go
  3. 305
      weed/mq/kafka/protocol/flexible_versions_integration_test.go
  4. 486
      weed/mq/kafka/protocol/flexible_versions_test.go
  5. 86
      weed/mq/kafka/protocol/handler.go

4
weed/mq/kafka/protocol/api_versions_test.go

@ -10,7 +10,7 @@ func TestApiVersions_AdvertisedVersionsMatch(t *testing.T) {
handler := NewTestHandler()
defer handler.Close()
response, err := handler.handleApiVersions(12345)
response, err := handler.handleApiVersions(12345, 0)
if err != nil {
t.Fatalf("handleApiVersions failed: %v", err)
}
@ -253,7 +253,7 @@ func TestApiVersions_ResponseFormat(t *testing.T) {
handler := NewTestHandler()
defer handler.Close()
response, err := handler.handleApiVersions(99999)
response, err := handler.handleApiVersions(99999, 0)
if err != nil {
t.Fatalf("handleApiVersions failed: %v", err)
}

359
weed/mq/kafka/protocol/flexible_versions.go

@ -0,0 +1,359 @@
package protocol
import (
"encoding/binary"
"fmt"
)
// FlexibleVersions provides utilities for handling Kafka flexible versions protocol
// Flexible versions use compact arrays/strings and tagged fields for backward compatibility
// CompactArrayLength encodes a length for compact arrays
// Compact arrays encode length as length+1, where 0 means empty array
func CompactArrayLength(length uint32) []byte {
if length == 0 {
return []byte{0} // Empty array
}
return EncodeUvarint(length + 1)
}
// DecodeCompactArrayLength decodes a compact array length
// Returns the actual length and number of bytes consumed
func DecodeCompactArrayLength(data []byte) (uint32, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("no data for compact array length")
}
if data[0] == 0 {
return 0, 1, nil // Empty array
}
length, consumed, err := DecodeUvarint(data)
if err != nil {
return 0, 0, fmt.Errorf("decode compact array length: %w", err)
}
if length == 0 {
return 0, consumed, fmt.Errorf("invalid compact array length encoding")
}
return length - 1, consumed, nil
}
// CompactStringLength encodes a length for compact strings
// Compact strings encode length as length+1, where 0 means null string
func CompactStringLength(length int) []byte {
if length < 0 {
return []byte{0} // Null string
}
return EncodeUvarint(uint32(length + 1))
}
// DecodeCompactStringLength decodes a compact string length
// Returns the actual length (-1 for null), and number of bytes consumed
func DecodeCompactStringLength(data []byte) (int, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("no data for compact string length")
}
if data[0] == 0 {
return -1, 1, nil // Null string
}
length, consumed, err := DecodeUvarint(data)
if err != nil {
return 0, 0, fmt.Errorf("decode compact string length: %w", err)
}
if length == 0 {
return 0, consumed, fmt.Errorf("invalid compact string length encoding")
}
return int(length - 1), consumed, nil
}
// EncodeUvarint encodes an unsigned integer using variable-length encoding
// This is used for compact arrays, strings, and tagged fields
func EncodeUvarint(value uint32) []byte {
var buf []byte
for value >= 0x80 {
buf = append(buf, byte(value)|0x80)
value >>= 7
}
buf = append(buf, byte(value))
return buf
}
// DecodeUvarint decodes a variable-length unsigned integer
// Returns the decoded value and number of bytes consumed
func DecodeUvarint(data []byte) (uint32, int, error) {
var value uint32
var shift uint
var consumed int
for i, b := range data {
consumed = i + 1
value |= uint32(b&0x7F) << shift
if (b & 0x80) == 0 {
return value, consumed, nil
}
shift += 7
if shift >= 32 {
return 0, consumed, fmt.Errorf("uvarint overflow")
}
}
return 0, consumed, fmt.Errorf("incomplete uvarint")
}
// TaggedField represents a tagged field in flexible versions
type TaggedField struct {
Tag uint32
Data []byte
}
// TaggedFields represents a collection of tagged fields
type TaggedFields struct {
Fields []TaggedField
}
// EncodeTaggedFields encodes tagged fields for flexible versions
func (tf *TaggedFields) Encode() []byte {
if len(tf.Fields) == 0 {
return []byte{0} // Empty tagged fields
}
var buf []byte
// Number of tagged fields
buf = append(buf, EncodeUvarint(uint32(len(tf.Fields)))...)
for _, field := range tf.Fields {
// Tag
buf = append(buf, EncodeUvarint(field.Tag)...)
// Size
buf = append(buf, EncodeUvarint(uint32(len(field.Data)))...)
// Data
buf = append(buf, field.Data...)
}
return buf
}
// DecodeTaggedFields decodes tagged fields from flexible versions
func DecodeTaggedFields(data []byte) (*TaggedFields, int, error) {
if len(data) == 0 {
return &TaggedFields{}, 0, fmt.Errorf("no data for tagged fields")
}
if data[0] == 0 {
return &TaggedFields{}, 1, nil // Empty tagged fields
}
offset := 0
// Number of tagged fields
numFields, consumed, err := DecodeUvarint(data[offset:])
if err != nil {
return nil, 0, fmt.Errorf("decode tagged fields count: %w", err)
}
offset += consumed
fields := make([]TaggedField, numFields)
for i := uint32(0); i < numFields; i++ {
// Tag
tag, consumed, err := DecodeUvarint(data[offset:])
if err != nil {
return nil, 0, fmt.Errorf("decode tagged field %d tag: %w", i, err)
}
offset += consumed
// Size
size, consumed, err := DecodeUvarint(data[offset:])
if err != nil {
return nil, 0, fmt.Errorf("decode tagged field %d size: %w", i, err)
}
offset += consumed
// Data
if offset+int(size) > len(data) {
return nil, 0, fmt.Errorf("tagged field %d data truncated", i)
}
fields[i] = TaggedField{
Tag: tag,
Data: data[offset : offset+int(size)],
}
offset += int(size)
}
return &TaggedFields{Fields: fields}, offset, nil
}
// IsFlexibleVersion determines if an API version uses flexible versions
// This is API-specific and based on when each API adopted flexible versions
func IsFlexibleVersion(apiKey, apiVersion uint16) bool {
switch apiKey {
case 18: // ApiVersions
return apiVersion >= 3
case 3: // Metadata
return apiVersion >= 9
case 1: // Fetch
return apiVersion >= 12
case 0: // Produce
return apiVersion >= 9
case 11: // JoinGroup
return apiVersion >= 6
case 14: // SyncGroup
return apiVersion >= 4
case 8: // OffsetCommit
return apiVersion >= 8
case 9: // OffsetFetch
return apiVersion >= 6
case 10: // FindCoordinator
return apiVersion >= 3
case 12: // Heartbeat
return apiVersion >= 4
case 13: // LeaveGroup
return apiVersion >= 4
case 19: // CreateTopics
return apiVersion >= 2
case 20: // DeleteTopics
return apiVersion >= 4
default:
return false
}
}
// FlexibleString encodes a string for flexible versions (compact format)
func FlexibleString(s string) []byte {
if s == "" {
return []byte{0} // Null string
}
var buf []byte
buf = append(buf, CompactStringLength(len(s))...)
buf = append(buf, []byte(s)...)
return buf
}
// FlexibleNullableString encodes a nullable string for flexible versions
func FlexibleNullableString(s *string) []byte {
if s == nil {
return []byte{0} // Null string
}
return FlexibleString(*s)
}
// DecodeFlexibleString decodes a flexible string
// Returns the string (empty for null) and bytes consumed
func DecodeFlexibleString(data []byte) (string, int, error) {
length, consumed, err := DecodeCompactStringLength(data)
if err != nil {
return "", 0, err
}
if length < 0 {
return "", consumed, nil // Null string -> empty string
}
if consumed+length > len(data) {
return "", 0, fmt.Errorf("string data truncated")
}
return string(data[consumed : consumed+length]), consumed + length, nil
}
// FlexibleVersionHeader handles the request header parsing for flexible versions
type FlexibleVersionHeader struct {
APIKey uint16
APIVersion uint16
CorrelationID uint32
ClientID *string
TaggedFields *TaggedFields
}
// ParseRequestHeader parses a Kafka request header, handling both regular and flexible versions
func ParseRequestHeader(data []byte) (*FlexibleVersionHeader, []byte, error) {
if len(data) < 8 {
return nil, nil, fmt.Errorf("header too short")
}
header := &FlexibleVersionHeader{}
offset := 0
// API Key (2 bytes)
header.APIKey = binary.BigEndian.Uint16(data[offset : offset+2])
offset += 2
// API Version (2 bytes)
header.APIVersion = binary.BigEndian.Uint16(data[offset : offset+2])
offset += 2
// Correlation ID (4 bytes)
header.CorrelationID = binary.BigEndian.Uint32(data[offset : offset+4])
offset += 4
// Client ID handling depends on flexible version
isFlexible := IsFlexibleVersion(header.APIKey, header.APIVersion)
if isFlexible {
// Flexible versions use compact strings
clientID, consumed, err := DecodeFlexibleString(data[offset:])
if err != nil {
return nil, nil, fmt.Errorf("decode flexible client_id: %w", err)
}
offset += consumed
if clientID != "" {
header.ClientID = &clientID
}
// Parse tagged fields in header
taggedFields, consumed, err := DecodeTaggedFields(data[offset:])
if err != nil {
return nil, nil, fmt.Errorf("decode header tagged fields: %w", err)
}
offset += consumed
header.TaggedFields = taggedFields
} else {
// Regular versions use standard strings
if len(data) < offset+2 {
return nil, nil, fmt.Errorf("missing client_id length")
}
clientIDLen := int16(binary.BigEndian.Uint16(data[offset : offset+2]))
offset += 2
if clientIDLen >= 0 {
if len(data) < offset+int(clientIDLen) {
return nil, nil, fmt.Errorf("client_id truncated")
}
clientID := string(data[offset : offset+int(clientIDLen)])
header.ClientID = &clientID
offset += int(clientIDLen)
}
// No tagged fields in regular versions
}
return header, data[offset:], nil
}
// EncodeFlexibleResponse encodes a response with proper flexible version formatting
func EncodeFlexibleResponse(correlationID uint32, data []byte, hasTaggedFields bool) []byte {
response := make([]byte, 4)
binary.BigEndian.PutUint32(response, correlationID)
response = append(response, data...)
if hasTaggedFields {
// Add empty tagged fields for flexible responses
response = append(response, 0)
}
return response
}

305
weed/mq/kafka/protocol/flexible_versions_integration_test.go

@ -0,0 +1,305 @@
package protocol
import (
"encoding/binary"
"testing"
)
func TestApiVersions_FlexibleVersionSupport(t *testing.T) {
handler := NewTestHandler()
testCases := []struct {
name string
apiVersion uint16
expectFlexible bool
}{
{"ApiVersions v0", 0, false},
{"ApiVersions v1", 1, false},
{"ApiVersions v2", 2, false},
{"ApiVersions v3", 3, true},
{"ApiVersions v4", 4, true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
correlationID := uint32(12345)
response, err := handler.handleApiVersions(correlationID, tc.apiVersion)
if err != nil {
t.Fatalf("handleApiVersions failed: %v", err)
}
if len(response) < 4 {
t.Fatalf("Response too short: %d bytes", len(response))
}
// Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(response[0:4])
if respCorrelationID != correlationID {
t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID)
}
// Check error code
errorCode := binary.BigEndian.Uint16(response[4:6])
if errorCode != 0 {
t.Errorf("Error code = %d, want 0", errorCode)
}
// Parse API keys count based on version
offset := 6
var apiKeysCount uint32
if tc.expectFlexible {
// Should use compact array format
count, consumed, err := DecodeCompactArrayLength(response[offset:])
if err != nil {
t.Fatalf("Failed to decode compact array length: %v", err)
}
apiKeysCount = count
offset += consumed
} else {
// Should use regular array format
if len(response) < offset+4 {
t.Fatalf("Response too short for regular array length")
}
apiKeysCount = binary.BigEndian.Uint32(response[offset:offset+4])
offset += 4
}
if apiKeysCount != 14 {
t.Errorf("API keys count = %d, want 14", apiKeysCount)
}
// Verify that we have enough data for all API keys
// Each API key entry is 6 bytes: api_key(2) + min_version(2) + max_version(2)
expectedMinSize := offset + int(apiKeysCount*6)
if tc.expectFlexible {
expectedMinSize += 1 // tagged fields
}
if len(response) < expectedMinSize {
t.Errorf("Response too short: got %d bytes, expected at least %d", len(response), expectedMinSize)
}
// Check that ApiVersions API itself is properly listed
// API Key 18 should be the first entry
if len(response) >= offset+6 {
apiKey := binary.BigEndian.Uint16(response[offset:offset+2])
minVersion := binary.BigEndian.Uint16(response[offset+2:offset+4])
maxVersion := binary.BigEndian.Uint16(response[offset+4:offset+6])
if apiKey != 18 {
t.Errorf("First API key = %d, want 18 (ApiVersions)", apiKey)
}
if minVersion != 0 {
t.Errorf("ApiVersions min version = %d, want 0", minVersion)
}
if maxVersion != 3 {
t.Errorf("ApiVersions max version = %d, want 3", maxVersion)
}
}
t.Logf("ApiVersions v%d response: %d bytes, flexible: %v", tc.apiVersion, len(response), tc.expectFlexible)
})
}
}
func TestFlexibleVersions_HeaderParsingIntegration(t *testing.T) {
testCases := []struct {
name string
apiKey uint16
apiVersion uint16
clientID string
expectFlexible bool
}{
{"Metadata v8 (regular)", 3, 8, "test-client", false},
{"Metadata v9 (flexible)", 3, 9, "test-client", true},
{"ApiVersions v2 (regular)", 18, 2, "test-client", false},
{"ApiVersions v3 (flexible)", 18, 3, "test-client", true},
{"CreateTopics v1 (regular)", 19, 1, "test-client", false},
{"CreateTopics v2 (flexible)", 19, 2, "test-client", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Construct request header
var headerData []byte
// API Key (2 bytes)
headerData = append(headerData, byte(tc.apiKey>>8), byte(tc.apiKey))
// API Version (2 bytes)
headerData = append(headerData, byte(tc.apiVersion>>8), byte(tc.apiVersion))
// Correlation ID (4 bytes)
correlationID := uint32(54321)
corrBytes := make([]byte, 4)
binary.BigEndian.PutUint32(corrBytes, correlationID)
headerData = append(headerData, corrBytes...)
if tc.expectFlexible {
// Flexible version: compact string for client ID
headerData = append(headerData, FlexibleString(tc.clientID)...)
// Empty tagged fields
headerData = append(headerData, 0)
} else {
// Regular version: standard string for client ID
clientIDBytes := []byte(tc.clientID)
headerData = append(headerData, byte(len(clientIDBytes)>>8), byte(len(clientIDBytes)))
headerData = append(headerData, clientIDBytes...)
}
// Add dummy request body
headerData = append(headerData, 1, 2, 3, 4)
// Parse header
header, body, err := ParseRequestHeader(headerData)
if err != nil {
t.Fatalf("ParseRequestHeader failed: %v", err)
}
// Validate parsed header
if header.APIKey != tc.apiKey {
t.Errorf("APIKey = %d, want %d", header.APIKey, tc.apiKey)
}
if header.APIVersion != tc.apiVersion {
t.Errorf("APIVersion = %d, want %d", header.APIVersion, tc.apiVersion)
}
if header.CorrelationID != correlationID {
t.Errorf("CorrelationID = %d, want %d", header.CorrelationID, correlationID)
}
if header.ClientID == nil || *header.ClientID != tc.clientID {
t.Errorf("ClientID = %v, want %s", header.ClientID, tc.clientID)
}
// Check tagged fields presence
hasTaggedFields := header.TaggedFields != nil
if hasTaggedFields != tc.expectFlexible {
t.Errorf("Tagged fields present = %v, want %v", hasTaggedFields, tc.expectFlexible)
}
// Validate body
expectedBody := []byte{1, 2, 3, 4}
if len(body) != len(expectedBody) {
t.Errorf("Body length = %d, want %d", len(body), len(expectedBody))
}
for i, b := range expectedBody {
if i < len(body) && body[i] != b {
t.Errorf("Body[%d] = %d, want %d", i, body[i], b)
}
}
t.Logf("Header parsing for %s v%d: flexible=%v, client=%s",
getAPIName(tc.apiKey), tc.apiVersion, tc.expectFlexible, tc.clientID)
})
}
}
func TestCreateTopics_FlexibleVersionConsistency(t *testing.T) {
handler := NewTestHandler()
// Test that CreateTopics v2+ continues to work correctly with flexible version utilities
correlationID := uint32(99999)
// Build CreateTopics v2 request using flexible version utilities
var requestData []byte
// Topics array (compact: 1 topic = 2)
requestData = append(requestData, 2)
// Topic name (compact string)
topicName := "flexible-test-topic"
requestData = append(requestData, FlexibleString(topicName)...)
// Number of partitions (4 bytes)
requestData = append(requestData, 0, 0, 0, 3)
// Replication factor (2 bytes)
requestData = append(requestData, 0, 1)
// Configs array (compact: empty = 0)
requestData = append(requestData, 0)
// Tagged fields (empty)
requestData = append(requestData, 0)
// Timeout (4 bytes)
requestData = append(requestData, 0, 0, 0x27, 0x10) // 10000ms
// Validate only (1 byte)
requestData = append(requestData, 0)
// Tagged fields at end
requestData = append(requestData, 0)
// Call CreateTopics v2
response, err := handler.handleCreateTopicsV2Plus(correlationID, 2, requestData)
if err != nil {
t.Fatalf("handleCreateTopicsV2Plus failed: %v", err)
}
if len(response) < 8 {
t.Fatalf("Response too short: %d bytes", len(response))
}
// Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(response[0:4])
if respCorrelationID != correlationID {
t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID)
}
// Verify topic was created
if !handler.seaweedMQHandler.TopicExists(topicName) {
t.Errorf("Topic '%s' was not created", topicName)
}
t.Logf("CreateTopics v2 with flexible utilities: topic created successfully")
}
func BenchmarkFlexibleVersions_HeaderParsing(b *testing.B) {
// Pre-construct headers for different scenarios
scenarios := []struct {
name string
data []byte
}{
{
name: "Regular_v1",
data: func() []byte {
var data []byte
data = append(data, 0, 3, 0, 1) // Metadata v1
corrBytes := make([]byte, 4)
binary.BigEndian.PutUint32(corrBytes, 12345)
data = append(data, corrBytes...)
data = append(data, 0, 11, 'b', 'e', 'n', 'c', 'h', '-', 'c', 'l', 'i', 'e', 'n', 't')
data = append(data, 1, 2, 3)
return data
}(),
},
{
name: "Flexible_v9",
data: func() []byte {
var data []byte
data = append(data, 0, 3, 0, 9) // Metadata v9
corrBytes := make([]byte, 4)
binary.BigEndian.PutUint32(corrBytes, 12345)
data = append(data, corrBytes...)
data = append(data, FlexibleString("bench-client")...)
data = append(data, 0) // empty tagged fields
data = append(data, 1, 2, 3)
return data
}(),
},
}
for _, scenario := range scenarios {
b.Run(scenario.name, func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, err := ParseRequestHeader(scenario.data)
if err != nil {
b.Fatalf("ParseRequestHeader failed: %v", err)
}
}
})
}
}

486
weed/mq/kafka/protocol/flexible_versions_test.go

@ -0,0 +1,486 @@
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
)
func TestEncodeDecodeUvarint(t *testing.T) {
testCases := []uint32{
0, 1, 127, 128, 255, 256, 16383, 16384, 32767, 32768, 65535, 65536,
0x1FFFFF, 0x200000, 0x0FFFFFFF, 0x10000000, 0xFFFFFFFF,
}
for _, value := range testCases {
t.Run(fmt.Sprintf("value_%d", value), func(t *testing.T) {
encoded := EncodeUvarint(value)
decoded, consumed, err := DecodeUvarint(encoded)
if err != nil {
t.Fatalf("DecodeUvarint failed: %v", err)
}
if decoded != value {
t.Errorf("Decoded value %d != original %d", decoded, value)
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
}
}
func TestCompactArrayLength(t *testing.T) {
testCases := []struct {
name string
length uint32
expected []byte
}{
{"Empty array", 0, []byte{0}},
{"Single element", 1, []byte{2}},
{"Small array", 10, []byte{11}},
{"Large array", 127, []byte{128, 1}}, // 128 = 127+1 encoded as varint (two bytes since >= 128)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := CompactArrayLength(tc.length)
if !bytes.Equal(encoded, tc.expected) {
t.Errorf("CompactArrayLength(%d) = %v, want %v", tc.length, encoded, tc.expected)
}
// Test round trip
decoded, consumed, err := DecodeCompactArrayLength(encoded)
if err != nil {
t.Fatalf("DecodeCompactArrayLength failed: %v", err)
}
if decoded != tc.length {
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length)
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
}
}
func TestCompactStringLength(t *testing.T) {
testCases := []struct {
name string
length int
expected []byte
}{
{"Null string", -1, []byte{0}},
{"Empty string", 0, []byte{1}},
{"Short string", 5, []byte{6}},
{"Medium string", 100, []byte{101}}, // 101 encoded as varint (single byte since < 128)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := CompactStringLength(tc.length)
if !bytes.Equal(encoded, tc.expected) {
t.Errorf("CompactStringLength(%d) = %v, want %v", tc.length, encoded, tc.expected)
}
// Test round trip
decoded, consumed, err := DecodeCompactStringLength(encoded)
if err != nil {
t.Fatalf("DecodeCompactStringLength failed: %v", err)
}
if decoded != tc.length {
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length)
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
}
}
func TestFlexibleString(t *testing.T) {
testCases := []struct {
name string
input string
expected []byte
}{
{"Empty string", "", []byte{0}},
{"Hello", "hello", []byte{6, 'h', 'e', 'l', 'l', 'o'}},
{"Unicode", "测试", []byte{7, 0xE6, 0xB5, 0x8B, 0xE8, 0xAF, 0x95}}, // UTF-8 encoded
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := FlexibleString(tc.input)
if !bytes.Equal(encoded, tc.expected) {
t.Errorf("FlexibleString(%q) = %v, want %v", tc.input, encoded, tc.expected)
}
// Test round trip
decoded, consumed, err := DecodeFlexibleString(encoded)
if err != nil {
t.Fatalf("DecodeFlexibleString failed: %v", err)
}
if decoded != tc.input {
t.Errorf("Round trip failed: got %q, want %q", decoded, tc.input)
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
}
}
func TestFlexibleNullableString(t *testing.T) {
// Null string
nullResult := FlexibleNullableString(nil)
expected := []byte{0}
if !bytes.Equal(nullResult, expected) {
t.Errorf("FlexibleNullableString(nil) = %v, want %v", nullResult, expected)
}
// Non-null string
str := "test"
nonNullResult := FlexibleNullableString(&str)
expectedNonNull := []byte{5, 't', 'e', 's', 't'}
if !bytes.Equal(nonNullResult, expectedNonNull) {
t.Errorf("FlexibleNullableString(&%q) = %v, want %v", str, nonNullResult, expectedNonNull)
}
}
func TestTaggedFields(t *testing.T) {
t.Run("Empty tagged fields", func(t *testing.T) {
tf := &TaggedFields{}
encoded := tf.Encode()
expected := []byte{0}
if !bytes.Equal(encoded, expected) {
t.Errorf("Empty TaggedFields.Encode() = %v, want %v", encoded, expected)
}
// Test round trip
decoded, consumed, err := DecodeTaggedFields(encoded)
if err != nil {
t.Fatalf("DecodeTaggedFields failed: %v", err)
}
if len(decoded.Fields) != 0 {
t.Errorf("Decoded tagged fields length = %d, want 0", len(decoded.Fields))
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
t.Run("Single tagged field", func(t *testing.T) {
tf := &TaggedFields{
Fields: []TaggedField{
{Tag: 1, Data: []byte("test")},
},
}
encoded := tf.Encode()
// Test round trip
decoded, consumed, err := DecodeTaggedFields(encoded)
if err != nil {
t.Fatalf("DecodeTaggedFields failed: %v", err)
}
if len(decoded.Fields) != 1 {
t.Fatalf("Decoded tagged fields length = %d, want 1", len(decoded.Fields))
}
field := decoded.Fields[0]
if field.Tag != 1 {
t.Errorf("Decoded tag = %d, want 1", field.Tag)
}
if !bytes.Equal(field.Data, []byte("test")) {
t.Errorf("Decoded data = %v, want %v", field.Data, []byte("test"))
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
t.Run("Multiple tagged fields", func(t *testing.T) {
tf := &TaggedFields{
Fields: []TaggedField{
{Tag: 1, Data: []byte("first")},
{Tag: 5, Data: []byte("second")},
},
}
encoded := tf.Encode()
// Test round trip
decoded, consumed, err := DecodeTaggedFields(encoded)
if err != nil {
t.Fatalf("DecodeTaggedFields failed: %v", err)
}
if len(decoded.Fields) != 2 {
t.Fatalf("Decoded tagged fields length = %d, want 2", len(decoded.Fields))
}
// Check first field
field1 := decoded.Fields[0]
if field1.Tag != 1 {
t.Errorf("Decoded field 1 tag = %d, want 1", field1.Tag)
}
if !bytes.Equal(field1.Data, []byte("first")) {
t.Errorf("Decoded field 1 data = %v, want %v", field1.Data, []byte("first"))
}
// Check second field
field2 := decoded.Fields[1]
if field2.Tag != 5 {
t.Errorf("Decoded field 2 tag = %d, want 5", field2.Tag)
}
if !bytes.Equal(field2.Data, []byte("second")) {
t.Errorf("Decoded field 2 data = %v, want %v", field2.Data, []byte("second"))
}
if consumed != len(encoded) {
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
}
})
}
func TestIsFlexibleVersion(t *testing.T) {
testCases := []struct {
apiKey uint16
apiVersion uint16
expected bool
name string
}{
// ApiVersions
{18, 2, false, "ApiVersions v2"},
{18, 3, true, "ApiVersions v3"},
{18, 4, true, "ApiVersions v4"},
// Metadata
{3, 8, false, "Metadata v8"},
{3, 9, true, "Metadata v9"},
{3, 10, true, "Metadata v10"},
// Fetch
{1, 11, false, "Fetch v11"},
{1, 12, true, "Fetch v12"},
{1, 13, true, "Fetch v13"},
// Produce
{0, 8, false, "Produce v8"},
{0, 9, true, "Produce v9"},
{0, 10, true, "Produce v10"},
// CreateTopics
{19, 1, false, "CreateTopics v1"},
{19, 2, true, "CreateTopics v2"},
{19, 3, true, "CreateTopics v3"},
// Unknown API
{99, 1, false, "Unknown API"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsFlexibleVersion(tc.apiKey, tc.apiVersion)
if result != tc.expected {
t.Errorf("IsFlexibleVersion(%d, %d) = %v, want %v",
tc.apiKey, tc.apiVersion, result, tc.expected)
}
})
}
}
func TestParseRequestHeader(t *testing.T) {
t.Run("Regular version header", func(t *testing.T) {
// Construct a regular version header (Metadata v1)
data := make([]byte, 0)
data = append(data, 0, 3) // API Key = 3 (Metadata)
data = append(data, 0, 1) // API Version = 1
data = append(data, 0, 0, 0, 123) // Correlation ID = 123
data = append(data, 0, 4) // Client ID length = 4
data = append(data, 't', 'e', 's', 't') // Client ID = "test"
data = append(data, 1, 2, 3) // Request body
header, body, err := ParseRequestHeader(data)
if err != nil {
t.Fatalf("ParseRequestHeader failed: %v", err)
}
if header.APIKey != 3 {
t.Errorf("APIKey = %d, want 3", header.APIKey)
}
if header.APIVersion != 1 {
t.Errorf("APIVersion = %d, want 1", header.APIVersion)
}
if header.CorrelationID != 123 {
t.Errorf("CorrelationID = %d, want 123", header.CorrelationID)
}
if header.ClientID == nil || *header.ClientID != "test" {
t.Errorf("ClientID = %v, want 'test'", header.ClientID)
}
if header.TaggedFields != nil {
t.Errorf("TaggedFields should be nil for regular versions")
}
expectedBody := []byte{1, 2, 3}
if !bytes.Equal(body, expectedBody) {
t.Errorf("Body = %v, want %v", body, expectedBody)
}
})
t.Run("Flexible version header", func(t *testing.T) {
// Construct a flexible version header (ApiVersions v3)
data := make([]byte, 0)
data = append(data, 0, 18) // API Key = 18 (ApiVersions)
data = append(data, 0, 3) // API Version = 3 (flexible)
// Correlation ID = 456 (4 bytes, big endian)
correlationID := make([]byte, 4)
binary.BigEndian.PutUint32(correlationID, 456)
data = append(data, correlationID...)
data = append(data, 5, 't', 'e', 's', 't') // Client ID = "test" (compact string)
data = append(data, 0) // Empty tagged fields
data = append(data, 4, 5, 6) // Request body
header, body, err := ParseRequestHeader(data)
if err != nil {
t.Fatalf("ParseRequestHeader failed: %v", err)
}
if header.APIKey != 18 {
t.Errorf("APIKey = %d, want 18", header.APIKey)
}
if header.APIVersion != 3 {
t.Errorf("APIVersion = %d, want 3", header.APIVersion)
}
if header.CorrelationID != 456 {
t.Errorf("CorrelationID = %d, want 456", header.CorrelationID)
}
if header.ClientID == nil || *header.ClientID != "test" {
t.Errorf("ClientID = %v, want 'test'", header.ClientID)
}
if header.TaggedFields == nil {
t.Errorf("TaggedFields should not be nil for flexible versions")
}
if len(header.TaggedFields.Fields) != 0 {
t.Errorf("TaggedFields should be empty")
}
expectedBody := []byte{4, 5, 6}
if !bytes.Equal(body, expectedBody) {
t.Errorf("Body = %v, want %v", body, expectedBody)
}
})
t.Run("Null client ID", func(t *testing.T) {
// Regular version with null client ID
data := make([]byte, 0)
data = append(data, 0, 3) // API Key = 3 (Metadata)
data = append(data, 0, 1) // API Version = 1
// Correlation ID = 789 (4 bytes, big endian)
correlationID := make([]byte, 4)
binary.BigEndian.PutUint32(correlationID, 789)
data = append(data, correlationID...)
data = append(data, 0xFF, 0xFF) // Client ID length = -1 (null)
data = append(data, 7, 8, 9) // Request body
header, body, err := ParseRequestHeader(data)
if err != nil {
t.Fatalf("ParseRequestHeader failed: %v", err)
}
if header.ClientID != nil {
t.Errorf("ClientID = %v, want nil", header.ClientID)
}
expectedBody := []byte{7, 8, 9}
if !bytes.Equal(body, expectedBody) {
t.Errorf("Body = %v, want %v", body, expectedBody)
}
})
}
func TestEncodeFlexibleResponse(t *testing.T) {
correlationID := uint32(123)
data := []byte{1, 2, 3, 4}
t.Run("Without tagged fields", func(t *testing.T) {
result := EncodeFlexibleResponse(correlationID, data, false)
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4}
if !bytes.Equal(result, expected) {
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected)
}
})
t.Run("With tagged fields", func(t *testing.T) {
result := EncodeFlexibleResponse(correlationID, data, true)
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4, 0} // 0 at end for empty tagged fields
if !bytes.Equal(result, expected) {
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected)
}
})
}
func BenchmarkEncodeUvarint(b *testing.B) {
testValues := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, val := range testValues {
EncodeUvarint(val)
}
}
}
func BenchmarkDecodeUvarint(b *testing.B) {
// Pre-encode test values
testData := [][]byte{
EncodeUvarint(0),
EncodeUvarint(127),
EncodeUvarint(128),
EncodeUvarint(16383),
EncodeUvarint(16384),
EncodeUvarint(65535),
EncodeUvarint(65536),
EncodeUvarint(0xFFFFFFFF),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, data := range testData {
DecodeUvarint(data)
}
}
}
func BenchmarkFlexibleString(b *testing.B) {
testStrings := []string{"", "a", "hello", "this is a longer test string", "测试中文字符串"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, s := range testStrings {
FlexibleString(s)
}
}
}

86
weed/mq/kafka/protocol/handler.go

@ -300,24 +300,49 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
continue
}
// Strip client_id (nullable STRING) from header to get pure request body
bodyOffset := 8
if len(messageBuf) < bodyOffset+2 {
return fmt.Errorf("invalid header: missing client_id length")
}
clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
bodyOffset += 2
if clientIDLen >= 0 {
if len(messageBuf) < bodyOffset+int(clientIDLen) {
return fmt.Errorf("invalid header: client_id truncated")
// Parse header using flexible version utilities for validation and client ID extraction
header, requestBody, parseErr := ParseRequestHeader(messageBuf)
if parseErr != nil {
// Fall back to basic header parsing if flexible version parsing fails
fmt.Printf("DEBUG: Flexible header parsing failed, using basic parsing: %v\n", parseErr)
// Basic header parsing fallback (original logic)
bodyOffset := 8
if len(messageBuf) < bodyOffset+2 {
return fmt.Errorf("invalid header: missing client_id length")
}
clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
bodyOffset += 2
if clientIDLen >= 0 {
if len(messageBuf) < bodyOffset+int(clientIDLen) {
return fmt.Errorf("invalid header: client_id truncated")
}
bodyOffset += int(clientIDLen)
}
// clientID := string(messageBuf[bodyOffset : bodyOffset+int(clientIDLen)])
bodyOffset += int(clientIDLen)
requestBody = messageBuf[bodyOffset:]
} else {
// client_id is null; nothing to skip
// Validate parsed header matches what we already extracted
if header.APIKey != apiKey || header.APIVersion != apiVersion || header.CorrelationID != correlationID {
fmt.Printf("DEBUG: Header parsing mismatch - using basic parsing as fallback\n")
// Fall back to basic parsing rather than failing
bodyOffset := 8
if len(messageBuf) < bodyOffset+2 {
return fmt.Errorf("invalid header: missing client_id length")
}
clientIDLen := int16(binary.BigEndian.Uint16(messageBuf[bodyOffset : bodyOffset+2]))
bodyOffset += 2
if clientIDLen >= 0 {
if len(messageBuf) < bodyOffset+int(clientIDLen) {
return fmt.Errorf("invalid header: client_id truncated")
}
bodyOffset += int(clientIDLen)
}
requestBody = messageBuf[bodyOffset:]
} else if header.ClientID != nil {
// Log client ID if available and parsing was successful
fmt.Printf("DEBUG: Client ID: %s\n", *header.ClientID)
}
}
// TODO: Flexible versions have tagged fields in header; ignored for now
requestBody := messageBuf[bodyOffset:]
// Handle the request based on API key and version
var response []byte
@ -325,7 +350,7 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
switch apiKey {
case 18: // ApiVersions
response, err = h.handleApiVersions(correlationID)
response, err = h.handleApiVersions(correlationID, apiVersion)
case 3: // Metadata
response, err = h.handleMetadata(correlationID, apiVersion, requestBody)
case 2: // ListOffsets
@ -410,11 +435,11 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
}
}
func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
// Build ApiVersions response manually
// Response format (v0): correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys
response := make([]byte, 0, 64)
func (h *Handler) handleApiVersions(correlationID uint32, apiVersion uint16) ([]byte, error) {
// Build ApiVersions response supporting flexible versions (v3+)
isFlexible := IsFlexibleVersion(18, apiVersion)
response := make([]byte, 0, 128)
// Correlation ID
correlationIDBytes := make([]byte, 4)
@ -424,8 +449,15 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
// Error code (0 = no error)
response = append(response, 0, 0)
// Number of API keys (compact array format in newer versions, but using basic format for simplicity)
response = append(response, 0, 0, 0, 14) // 14 API keys
// Number of API keys - use compact or regular array format based on version
apiKeysCount := uint32(14)
if isFlexible {
// Compact array format for flexible versions
response = append(response, CompactArrayLength(apiKeysCount)...)
} else {
// Regular array format for older versions
response = append(response, 0, 0, 0, 14) // 14 API keys
}
// API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2)
response = append(response, 0, 18) // API key 18
@ -500,7 +532,13 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
response = append(response, 0, 0) // min version 0
response = append(response, 0, 4) // max version 4
fmt.Printf("DEBUG: ApiVersions v0 response: %d bytes\n", len(response))
// Add tagged fields for flexible versions
if isFlexible {
// Empty tagged fields for now
response = append(response, 0)
}
fmt.Printf("DEBUG: ApiVersions v%d response: %d bytes\n", apiVersion, len(response))
return response, nil
}

Loading…
Cancel
Save