Browse Source
Phase 6: Add basic flexible versions support
Phase 6: Add basic flexible versions support
- Added flexible_versions.go with utilities for Kafka flexible versions (v3+) - Implemented ParseRequestHeader for compact string parsing and tagged fields - Added fallback mechanism in handler.go for backward compatibility - Updated handleApiVersions to support flexible version responses - Added comprehensive tests for flexible version utilities - All protocol tests passing with robust error handlingpull/7231/head
5 changed files with 1214 additions and 26 deletions
-
4weed/mq/kafka/protocol/api_versions_test.go
-
359weed/mq/kafka/protocol/flexible_versions.go
-
305weed/mq/kafka/protocol/flexible_versions_integration_test.go
-
486weed/mq/kafka/protocol/flexible_versions_test.go
-
84weed/mq/kafka/protocol/handler.go
@ -0,0 +1,359 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
) |
|||
|
|||
// FlexibleVersions provides utilities for handling Kafka flexible versions protocol
|
|||
// Flexible versions use compact arrays/strings and tagged fields for backward compatibility
|
|||
|
|||
// CompactArrayLength encodes a length for compact arrays
|
|||
// Compact arrays encode length as length+1, where 0 means empty array
|
|||
func CompactArrayLength(length uint32) []byte { |
|||
if length == 0 { |
|||
return []byte{0} // Empty array
|
|||
} |
|||
return EncodeUvarint(length + 1) |
|||
} |
|||
|
|||
// DecodeCompactArrayLength decodes a compact array length
|
|||
// Returns the actual length and number of bytes consumed
|
|||
func DecodeCompactArrayLength(data []byte) (uint32, int, error) { |
|||
if len(data) == 0 { |
|||
return 0, 0, fmt.Errorf("no data for compact array length") |
|||
} |
|||
|
|||
if data[0] == 0 { |
|||
return 0, 1, nil // Empty array
|
|||
} |
|||
|
|||
length, consumed, err := DecodeUvarint(data) |
|||
if err != nil { |
|||
return 0, 0, fmt.Errorf("decode compact array length: %w", err) |
|||
} |
|||
|
|||
if length == 0 { |
|||
return 0, consumed, fmt.Errorf("invalid compact array length encoding") |
|||
} |
|||
|
|||
return length - 1, consumed, nil |
|||
} |
|||
|
|||
// CompactStringLength encodes a length for compact strings
|
|||
// Compact strings encode length as length+1, where 0 means null string
|
|||
func CompactStringLength(length int) []byte { |
|||
if length < 0 { |
|||
return []byte{0} // Null string
|
|||
} |
|||
return EncodeUvarint(uint32(length + 1)) |
|||
} |
|||
|
|||
// DecodeCompactStringLength decodes a compact string length
|
|||
// Returns the actual length (-1 for null), and number of bytes consumed
|
|||
func DecodeCompactStringLength(data []byte) (int, int, error) { |
|||
if len(data) == 0 { |
|||
return 0, 0, fmt.Errorf("no data for compact string length") |
|||
} |
|||
|
|||
if data[0] == 0 { |
|||
return -1, 1, nil // Null string
|
|||
} |
|||
|
|||
length, consumed, err := DecodeUvarint(data) |
|||
if err != nil { |
|||
return 0, 0, fmt.Errorf("decode compact string length: %w", err) |
|||
} |
|||
|
|||
if length == 0 { |
|||
return 0, consumed, fmt.Errorf("invalid compact string length encoding") |
|||
} |
|||
|
|||
return int(length - 1), consumed, nil |
|||
} |
|||
|
|||
// EncodeUvarint encodes an unsigned integer using variable-length encoding
|
|||
// This is used for compact arrays, strings, and tagged fields
|
|||
func EncodeUvarint(value uint32) []byte { |
|||
var buf []byte |
|||
for value >= 0x80 { |
|||
buf = append(buf, byte(value)|0x80) |
|||
value >>= 7 |
|||
} |
|||
buf = append(buf, byte(value)) |
|||
return buf |
|||
} |
|||
|
|||
// DecodeUvarint decodes a variable-length unsigned integer
|
|||
// Returns the decoded value and number of bytes consumed
|
|||
func DecodeUvarint(data []byte) (uint32, int, error) { |
|||
var value uint32 |
|||
var shift uint |
|||
var consumed int |
|||
|
|||
for i, b := range data { |
|||
consumed = i + 1 |
|||
value |= uint32(b&0x7F) << shift |
|||
|
|||
if (b & 0x80) == 0 { |
|||
return value, consumed, nil |
|||
} |
|||
|
|||
shift += 7 |
|||
if shift >= 32 { |
|||
return 0, consumed, fmt.Errorf("uvarint overflow") |
|||
} |
|||
} |
|||
|
|||
return 0, consumed, fmt.Errorf("incomplete uvarint") |
|||
} |
|||
|
|||
// TaggedField represents a tagged field in flexible versions
|
|||
type TaggedField struct { |
|||
Tag uint32 |
|||
Data []byte |
|||
} |
|||
|
|||
// TaggedFields represents a collection of tagged fields
|
|||
type TaggedFields struct { |
|||
Fields []TaggedField |
|||
} |
|||
|
|||
// EncodeTaggedFields encodes tagged fields for flexible versions
|
|||
func (tf *TaggedFields) Encode() []byte { |
|||
if len(tf.Fields) == 0 { |
|||
return []byte{0} // Empty tagged fields
|
|||
} |
|||
|
|||
var buf []byte |
|||
|
|||
// Number of tagged fields
|
|||
buf = append(buf, EncodeUvarint(uint32(len(tf.Fields)))...) |
|||
|
|||
for _, field := range tf.Fields { |
|||
// Tag
|
|||
buf = append(buf, EncodeUvarint(field.Tag)...) |
|||
// Size
|
|||
buf = append(buf, EncodeUvarint(uint32(len(field.Data)))...) |
|||
// Data
|
|||
buf = append(buf, field.Data...) |
|||
} |
|||
|
|||
return buf |
|||
} |
|||
|
|||
// DecodeTaggedFields decodes tagged fields from flexible versions
|
|||
func DecodeTaggedFields(data []byte) (*TaggedFields, int, error) { |
|||
if len(data) == 0 { |
|||
return &TaggedFields{}, 0, fmt.Errorf("no data for tagged fields") |
|||
} |
|||
|
|||
if data[0] == 0 { |
|||
return &TaggedFields{}, 1, nil // Empty tagged fields
|
|||
} |
|||
|
|||
offset := 0 |
|||
|
|||
// Number of tagged fields
|
|||
numFields, consumed, err := DecodeUvarint(data[offset:]) |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("decode tagged fields count: %w", err) |
|||
} |
|||
offset += consumed |
|||
|
|||
fields := make([]TaggedField, numFields) |
|||
|
|||
for i := uint32(0); i < numFields; i++ { |
|||
// Tag
|
|||
tag, consumed, err := DecodeUvarint(data[offset:]) |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("decode tagged field %d tag: %w", i, err) |
|||
} |
|||
offset += consumed |
|||
|
|||
// Size
|
|||
size, consumed, err := DecodeUvarint(data[offset:]) |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("decode tagged field %d size: %w", i, err) |
|||
} |
|||
offset += consumed |
|||
|
|||
// Data
|
|||
if offset+int(size) > len(data) { |
|||
return nil, 0, fmt.Errorf("tagged field %d data truncated", i) |
|||
} |
|||
|
|||
fields[i] = TaggedField{ |
|||
Tag: tag, |
|||
Data: data[offset : offset+int(size)], |
|||
} |
|||
offset += int(size) |
|||
} |
|||
|
|||
return &TaggedFields{Fields: fields}, offset, nil |
|||
} |
|||
|
|||
// IsFlexibleVersion determines if an API version uses flexible versions
|
|||
// This is API-specific and based on when each API adopted flexible versions
|
|||
func IsFlexibleVersion(apiKey, apiVersion uint16) bool { |
|||
switch apiKey { |
|||
case 18: // ApiVersions
|
|||
return apiVersion >= 3 |
|||
case 3: // Metadata
|
|||
return apiVersion >= 9 |
|||
case 1: // Fetch
|
|||
return apiVersion >= 12 |
|||
case 0: // Produce
|
|||
return apiVersion >= 9 |
|||
case 11: // JoinGroup
|
|||
return apiVersion >= 6 |
|||
case 14: // SyncGroup
|
|||
return apiVersion >= 4 |
|||
case 8: // OffsetCommit
|
|||
return apiVersion >= 8 |
|||
case 9: // OffsetFetch
|
|||
return apiVersion >= 6 |
|||
case 10: // FindCoordinator
|
|||
return apiVersion >= 3 |
|||
case 12: // Heartbeat
|
|||
return apiVersion >= 4 |
|||
case 13: // LeaveGroup
|
|||
return apiVersion >= 4 |
|||
case 19: // CreateTopics
|
|||
return apiVersion >= 2 |
|||
case 20: // DeleteTopics
|
|||
return apiVersion >= 4 |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
// FlexibleString encodes a string for flexible versions (compact format)
|
|||
func FlexibleString(s string) []byte { |
|||
if s == "" { |
|||
return []byte{0} // Null string
|
|||
} |
|||
|
|||
var buf []byte |
|||
buf = append(buf, CompactStringLength(len(s))...) |
|||
buf = append(buf, []byte(s)...) |
|||
return buf |
|||
} |
|||
|
|||
// FlexibleNullableString encodes a nullable string for flexible versions
|
|||
func FlexibleNullableString(s *string) []byte { |
|||
if s == nil { |
|||
return []byte{0} // Null string
|
|||
} |
|||
return FlexibleString(*s) |
|||
} |
|||
|
|||
// DecodeFlexibleString decodes a flexible string
|
|||
// Returns the string (empty for null) and bytes consumed
|
|||
func DecodeFlexibleString(data []byte) (string, int, error) { |
|||
length, consumed, err := DecodeCompactStringLength(data) |
|||
if err != nil { |
|||
return "", 0, err |
|||
} |
|||
|
|||
if length < 0 { |
|||
return "", consumed, nil // Null string -> empty string
|
|||
} |
|||
|
|||
if consumed+length > len(data) { |
|||
return "", 0, fmt.Errorf("string data truncated") |
|||
} |
|||
|
|||
return string(data[consumed : consumed+length]), consumed + length, nil |
|||
} |
|||
|
|||
// FlexibleVersionHeader handles the request header parsing for flexible versions
|
|||
type FlexibleVersionHeader struct { |
|||
APIKey uint16 |
|||
APIVersion uint16 |
|||
CorrelationID uint32 |
|||
ClientID *string |
|||
TaggedFields *TaggedFields |
|||
} |
|||
|
|||
// ParseRequestHeader parses a Kafka request header, handling both regular and flexible versions
|
|||
func ParseRequestHeader(data []byte) (*FlexibleVersionHeader, []byte, error) { |
|||
if len(data) < 8 { |
|||
return nil, nil, fmt.Errorf("header too short") |
|||
} |
|||
|
|||
header := &FlexibleVersionHeader{} |
|||
offset := 0 |
|||
|
|||
// API Key (2 bytes)
|
|||
header.APIKey = binary.BigEndian.Uint16(data[offset : offset+2]) |
|||
offset += 2 |
|||
|
|||
// API Version (2 bytes)
|
|||
header.APIVersion = binary.BigEndian.Uint16(data[offset : offset+2]) |
|||
offset += 2 |
|||
|
|||
// Correlation ID (4 bytes)
|
|||
header.CorrelationID = binary.BigEndian.Uint32(data[offset : offset+4]) |
|||
offset += 4 |
|||
|
|||
// Client ID handling depends on flexible version
|
|||
isFlexible := IsFlexibleVersion(header.APIKey, header.APIVersion) |
|||
|
|||
if isFlexible { |
|||
// Flexible versions use compact strings
|
|||
clientID, consumed, err := DecodeFlexibleString(data[offset:]) |
|||
if err != nil { |
|||
return nil, nil, fmt.Errorf("decode flexible client_id: %w", err) |
|||
} |
|||
offset += consumed |
|||
|
|||
if clientID != "" { |
|||
header.ClientID = &clientID |
|||
} |
|||
|
|||
// Parse tagged fields in header
|
|||
taggedFields, consumed, err := DecodeTaggedFields(data[offset:]) |
|||
if err != nil { |
|||
return nil, nil, fmt.Errorf("decode header tagged fields: %w", err) |
|||
} |
|||
offset += consumed |
|||
header.TaggedFields = taggedFields |
|||
|
|||
} else { |
|||
// Regular versions use standard strings
|
|||
if len(data) < offset+2 { |
|||
return nil, nil, fmt.Errorf("missing client_id length") |
|||
} |
|||
|
|||
clientIDLen := int16(binary.BigEndian.Uint16(data[offset : offset+2])) |
|||
offset += 2 |
|||
|
|||
if clientIDLen >= 0 { |
|||
if len(data) < offset+int(clientIDLen) { |
|||
return nil, nil, fmt.Errorf("client_id truncated") |
|||
} |
|||
|
|||
clientID := string(data[offset : offset+int(clientIDLen)]) |
|||
header.ClientID = &clientID |
|||
offset += int(clientIDLen) |
|||
} |
|||
// No tagged fields in regular versions
|
|||
} |
|||
|
|||
return header, data[offset:], nil |
|||
} |
|||
|
|||
// EncodeFlexibleResponse encodes a response with proper flexible version formatting
|
|||
func EncodeFlexibleResponse(correlationID uint32, data []byte, hasTaggedFields bool) []byte { |
|||
response := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(response, correlationID) |
|||
response = append(response, data...) |
|||
|
|||
if hasTaggedFields { |
|||
// Add empty tagged fields for flexible responses
|
|||
response = append(response, 0) |
|||
} |
|||
|
|||
return response |
|||
} |
|||
@ -0,0 +1,305 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"testing" |
|||
) |
|||
|
|||
func TestApiVersions_FlexibleVersionSupport(t *testing.T) { |
|||
handler := NewTestHandler() |
|||
|
|||
testCases := []struct { |
|||
name string |
|||
apiVersion uint16 |
|||
expectFlexible bool |
|||
}{ |
|||
{"ApiVersions v0", 0, false}, |
|||
{"ApiVersions v1", 1, false}, |
|||
{"ApiVersions v2", 2, false}, |
|||
{"ApiVersions v3", 3, true}, |
|||
{"ApiVersions v4", 4, true}, |
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
correlationID := uint32(12345) |
|||
|
|||
response, err := handler.handleApiVersions(correlationID, tc.apiVersion) |
|||
if err != nil { |
|||
t.Fatalf("handleApiVersions failed: %v", err) |
|||
} |
|||
|
|||
if len(response) < 4 { |
|||
t.Fatalf("Response too short: %d bytes", len(response)) |
|||
} |
|||
|
|||
// Check correlation ID
|
|||
respCorrelationID := binary.BigEndian.Uint32(response[0:4]) |
|||
if respCorrelationID != correlationID { |
|||
t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID) |
|||
} |
|||
|
|||
// Check error code
|
|||
errorCode := binary.BigEndian.Uint16(response[4:6]) |
|||
if errorCode != 0 { |
|||
t.Errorf("Error code = %d, want 0", errorCode) |
|||
} |
|||
|
|||
// Parse API keys count based on version
|
|||
offset := 6 |
|||
var apiKeysCount uint32 |
|||
|
|||
if tc.expectFlexible { |
|||
// Should use compact array format
|
|||
count, consumed, err := DecodeCompactArrayLength(response[offset:]) |
|||
if err != nil { |
|||
t.Fatalf("Failed to decode compact array length: %v", err) |
|||
} |
|||
apiKeysCount = count |
|||
offset += consumed |
|||
} else { |
|||
// Should use regular array format
|
|||
if len(response) < offset+4 { |
|||
t.Fatalf("Response too short for regular array length") |
|||
} |
|||
apiKeysCount = binary.BigEndian.Uint32(response[offset:offset+4]) |
|||
offset += 4 |
|||
} |
|||
|
|||
if apiKeysCount != 14 { |
|||
t.Errorf("API keys count = %d, want 14", apiKeysCount) |
|||
} |
|||
|
|||
// Verify that we have enough data for all API keys
|
|||
// Each API key entry is 6 bytes: api_key(2) + min_version(2) + max_version(2)
|
|||
expectedMinSize := offset + int(apiKeysCount*6) |
|||
if tc.expectFlexible { |
|||
expectedMinSize += 1 // tagged fields
|
|||
} |
|||
|
|||
if len(response) < expectedMinSize { |
|||
t.Errorf("Response too short: got %d bytes, expected at least %d", len(response), expectedMinSize) |
|||
} |
|||
|
|||
// Check that ApiVersions API itself is properly listed
|
|||
// API Key 18 should be the first entry
|
|||
if len(response) >= offset+6 { |
|||
apiKey := binary.BigEndian.Uint16(response[offset:offset+2]) |
|||
minVersion := binary.BigEndian.Uint16(response[offset+2:offset+4]) |
|||
maxVersion := binary.BigEndian.Uint16(response[offset+4:offset+6]) |
|||
|
|||
if apiKey != 18 { |
|||
t.Errorf("First API key = %d, want 18 (ApiVersions)", apiKey) |
|||
} |
|||
if minVersion != 0 { |
|||
t.Errorf("ApiVersions min version = %d, want 0", minVersion) |
|||
} |
|||
if maxVersion != 3 { |
|||
t.Errorf("ApiVersions max version = %d, want 3", maxVersion) |
|||
} |
|||
} |
|||
|
|||
t.Logf("ApiVersions v%d response: %d bytes, flexible: %v", tc.apiVersion, len(response), tc.expectFlexible) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestFlexibleVersions_HeaderParsingIntegration(t *testing.T) { |
|||
testCases := []struct { |
|||
name string |
|||
apiKey uint16 |
|||
apiVersion uint16 |
|||
clientID string |
|||
expectFlexible bool |
|||
}{ |
|||
{"Metadata v8 (regular)", 3, 8, "test-client", false}, |
|||
{"Metadata v9 (flexible)", 3, 9, "test-client", true}, |
|||
{"ApiVersions v2 (regular)", 18, 2, "test-client", false}, |
|||
{"ApiVersions v3 (flexible)", 18, 3, "test-client", true}, |
|||
{"CreateTopics v1 (regular)", 19, 1, "test-client", false}, |
|||
{"CreateTopics v2 (flexible)", 19, 2, "test-client", true}, |
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
// Construct request header
|
|||
var headerData []byte |
|||
|
|||
// API Key (2 bytes)
|
|||
headerData = append(headerData, byte(tc.apiKey>>8), byte(tc.apiKey)) |
|||
|
|||
// API Version (2 bytes)
|
|||
headerData = append(headerData, byte(tc.apiVersion>>8), byte(tc.apiVersion)) |
|||
|
|||
// Correlation ID (4 bytes)
|
|||
correlationID := uint32(54321) |
|||
corrBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(corrBytes, correlationID) |
|||
headerData = append(headerData, corrBytes...) |
|||
|
|||
if tc.expectFlexible { |
|||
// Flexible version: compact string for client ID
|
|||
headerData = append(headerData, FlexibleString(tc.clientID)...) |
|||
// Empty tagged fields
|
|||
headerData = append(headerData, 0) |
|||
} else { |
|||
// Regular version: standard string for client ID
|
|||
clientIDBytes := []byte(tc.clientID) |
|||
headerData = append(headerData, byte(len(clientIDBytes)>>8), byte(len(clientIDBytes))) |
|||
headerData = append(headerData, clientIDBytes...) |
|||
} |
|||
|
|||
// Add dummy request body
|
|||
headerData = append(headerData, 1, 2, 3, 4) |
|||
|
|||
// Parse header
|
|||
header, body, err := ParseRequestHeader(headerData) |
|||
if err != nil { |
|||
t.Fatalf("ParseRequestHeader failed: %v", err) |
|||
} |
|||
|
|||
// Validate parsed header
|
|||
if header.APIKey != tc.apiKey { |
|||
t.Errorf("APIKey = %d, want %d", header.APIKey, tc.apiKey) |
|||
} |
|||
if header.APIVersion != tc.apiVersion { |
|||
t.Errorf("APIVersion = %d, want %d", header.APIVersion, tc.apiVersion) |
|||
} |
|||
if header.CorrelationID != correlationID { |
|||
t.Errorf("CorrelationID = %d, want %d", header.CorrelationID, correlationID) |
|||
} |
|||
if header.ClientID == nil || *header.ClientID != tc.clientID { |
|||
t.Errorf("ClientID = %v, want %s", header.ClientID, tc.clientID) |
|||
} |
|||
|
|||
// Check tagged fields presence
|
|||
hasTaggedFields := header.TaggedFields != nil |
|||
if hasTaggedFields != tc.expectFlexible { |
|||
t.Errorf("Tagged fields present = %v, want %v", hasTaggedFields, tc.expectFlexible) |
|||
} |
|||
|
|||
// Validate body
|
|||
expectedBody := []byte{1, 2, 3, 4} |
|||
if len(body) != len(expectedBody) { |
|||
t.Errorf("Body length = %d, want %d", len(body), len(expectedBody)) |
|||
} |
|||
for i, b := range expectedBody { |
|||
if i < len(body) && body[i] != b { |
|||
t.Errorf("Body[%d] = %d, want %d", i, body[i], b) |
|||
} |
|||
} |
|||
|
|||
t.Logf("Header parsing for %s v%d: flexible=%v, client=%s", |
|||
getAPIName(tc.apiKey), tc.apiVersion, tc.expectFlexible, tc.clientID) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestCreateTopics_FlexibleVersionConsistency(t *testing.T) { |
|||
handler := NewTestHandler() |
|||
|
|||
// Test that CreateTopics v2+ continues to work correctly with flexible version utilities
|
|||
correlationID := uint32(99999) |
|||
|
|||
// Build CreateTopics v2 request using flexible version utilities
|
|||
var requestData []byte |
|||
|
|||
// Topics array (compact: 1 topic = 2)
|
|||
requestData = append(requestData, 2) |
|||
|
|||
// Topic name (compact string)
|
|||
topicName := "flexible-test-topic" |
|||
requestData = append(requestData, FlexibleString(topicName)...) |
|||
|
|||
// Number of partitions (4 bytes)
|
|||
requestData = append(requestData, 0, 0, 0, 3) |
|||
|
|||
// Replication factor (2 bytes)
|
|||
requestData = append(requestData, 0, 1) |
|||
|
|||
// Configs array (compact: empty = 0)
|
|||
requestData = append(requestData, 0) |
|||
|
|||
// Tagged fields (empty)
|
|||
requestData = append(requestData, 0) |
|||
|
|||
// Timeout (4 bytes)
|
|||
requestData = append(requestData, 0, 0, 0x27, 0x10) // 10000ms
|
|||
|
|||
// Validate only (1 byte)
|
|||
requestData = append(requestData, 0) |
|||
|
|||
// Tagged fields at end
|
|||
requestData = append(requestData, 0) |
|||
|
|||
// Call CreateTopics v2
|
|||
response, err := handler.handleCreateTopicsV2Plus(correlationID, 2, requestData) |
|||
if err != nil { |
|||
t.Fatalf("handleCreateTopicsV2Plus failed: %v", err) |
|||
} |
|||
|
|||
if len(response) < 8 { |
|||
t.Fatalf("Response too short: %d bytes", len(response)) |
|||
} |
|||
|
|||
// Check correlation ID
|
|||
respCorrelationID := binary.BigEndian.Uint32(response[0:4]) |
|||
if respCorrelationID != correlationID { |
|||
t.Errorf("Correlation ID = %d, want %d", respCorrelationID, correlationID) |
|||
} |
|||
|
|||
// Verify topic was created
|
|||
if !handler.seaweedMQHandler.TopicExists(topicName) { |
|||
t.Errorf("Topic '%s' was not created", topicName) |
|||
} |
|||
|
|||
t.Logf("CreateTopics v2 with flexible utilities: topic created successfully") |
|||
} |
|||
|
|||
func BenchmarkFlexibleVersions_HeaderParsing(b *testing.B) { |
|||
// Pre-construct headers for different scenarios
|
|||
scenarios := []struct { |
|||
name string |
|||
data []byte |
|||
}{ |
|||
{ |
|||
name: "Regular_v1", |
|||
data: func() []byte { |
|||
var data []byte |
|||
data = append(data, 0, 3, 0, 1) // Metadata v1
|
|||
corrBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(corrBytes, 12345) |
|||
data = append(data, corrBytes...) |
|||
data = append(data, 0, 11, 'b', 'e', 'n', 'c', 'h', '-', 'c', 'l', 'i', 'e', 'n', 't') |
|||
data = append(data, 1, 2, 3) |
|||
return data |
|||
}(), |
|||
}, |
|||
{ |
|||
name: "Flexible_v9", |
|||
data: func() []byte { |
|||
var data []byte |
|||
data = append(data, 0, 3, 0, 9) // Metadata v9
|
|||
corrBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(corrBytes, 12345) |
|||
data = append(data, corrBytes...) |
|||
data = append(data, FlexibleString("bench-client")...) |
|||
data = append(data, 0) // empty tagged fields
|
|||
data = append(data, 1, 2, 3) |
|||
return data |
|||
}(), |
|||
}, |
|||
} |
|||
|
|||
for _, scenario := range scenarios { |
|||
b.Run(scenario.name, func(b *testing.B) { |
|||
b.ResetTimer() |
|||
for i := 0; i < b.N; i++ { |
|||
_, _, err := ParseRequestHeader(scenario.data) |
|||
if err != nil { |
|||
b.Fatalf("ParseRequestHeader failed: %v", err) |
|||
} |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
@ -0,0 +1,486 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"fmt" |
|||
"testing" |
|||
) |
|||
|
|||
func TestEncodeDecodeUvarint(t *testing.T) { |
|||
testCases := []uint32{ |
|||
0, 1, 127, 128, 255, 256, 16383, 16384, 32767, 32768, 65535, 65536, |
|||
0x1FFFFF, 0x200000, 0x0FFFFFFF, 0x10000000, 0xFFFFFFFF, |
|||
} |
|||
|
|||
for _, value := range testCases { |
|||
t.Run(fmt.Sprintf("value_%d", value), func(t *testing.T) { |
|||
encoded := EncodeUvarint(value) |
|||
decoded, consumed, err := DecodeUvarint(encoded) |
|||
|
|||
if err != nil { |
|||
t.Fatalf("DecodeUvarint failed: %v", err) |
|||
} |
|||
|
|||
if decoded != value { |
|||
t.Errorf("Decoded value %d != original %d", decoded, value) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestCompactArrayLength(t *testing.T) { |
|||
testCases := []struct { |
|||
name string |
|||
length uint32 |
|||
expected []byte |
|||
}{ |
|||
{"Empty array", 0, []byte{0}}, |
|||
{"Single element", 1, []byte{2}}, |
|||
{"Small array", 10, []byte{11}}, |
|||
{"Large array", 127, []byte{128, 1}}, // 128 = 127+1 encoded as varint (two bytes since >= 128)
|
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
encoded := CompactArrayLength(tc.length) |
|||
if !bytes.Equal(encoded, tc.expected) { |
|||
t.Errorf("CompactArrayLength(%d) = %v, want %v", tc.length, encoded, tc.expected) |
|||
} |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeCompactArrayLength(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeCompactArrayLength failed: %v", err) |
|||
} |
|||
|
|||
if decoded != tc.length { |
|||
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestCompactStringLength(t *testing.T) { |
|||
testCases := []struct { |
|||
name string |
|||
length int |
|||
expected []byte |
|||
}{ |
|||
{"Null string", -1, []byte{0}}, |
|||
{"Empty string", 0, []byte{1}}, |
|||
{"Short string", 5, []byte{6}}, |
|||
{"Medium string", 100, []byte{101}}, // 101 encoded as varint (single byte since < 128)
|
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
encoded := CompactStringLength(tc.length) |
|||
if !bytes.Equal(encoded, tc.expected) { |
|||
t.Errorf("CompactStringLength(%d) = %v, want %v", tc.length, encoded, tc.expected) |
|||
} |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeCompactStringLength(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeCompactStringLength failed: %v", err) |
|||
} |
|||
|
|||
if decoded != tc.length { |
|||
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestFlexibleString(t *testing.T) { |
|||
testCases := []struct { |
|||
name string |
|||
input string |
|||
expected []byte |
|||
}{ |
|||
{"Empty string", "", []byte{0}}, |
|||
{"Hello", "hello", []byte{6, 'h', 'e', 'l', 'l', 'o'}}, |
|||
{"Unicode", "测试", []byte{7, 0xE6, 0xB5, 0x8B, 0xE8, 0xAF, 0x95}}, // UTF-8 encoded
|
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
encoded := FlexibleString(tc.input) |
|||
if !bytes.Equal(encoded, tc.expected) { |
|||
t.Errorf("FlexibleString(%q) = %v, want %v", tc.input, encoded, tc.expected) |
|||
} |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeFlexibleString(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeFlexibleString failed: %v", err) |
|||
} |
|||
|
|||
if decoded != tc.input { |
|||
t.Errorf("Round trip failed: got %q, want %q", decoded, tc.input) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestFlexibleNullableString(t *testing.T) { |
|||
// Null string
|
|||
nullResult := FlexibleNullableString(nil) |
|||
expected := []byte{0} |
|||
if !bytes.Equal(nullResult, expected) { |
|||
t.Errorf("FlexibleNullableString(nil) = %v, want %v", nullResult, expected) |
|||
} |
|||
|
|||
// Non-null string
|
|||
str := "test" |
|||
nonNullResult := FlexibleNullableString(&str) |
|||
expectedNonNull := []byte{5, 't', 'e', 's', 't'} |
|||
if !bytes.Equal(nonNullResult, expectedNonNull) { |
|||
t.Errorf("FlexibleNullableString(&%q) = %v, want %v", str, nonNullResult, expectedNonNull) |
|||
} |
|||
} |
|||
|
|||
func TestTaggedFields(t *testing.T) { |
|||
t.Run("Empty tagged fields", func(t *testing.T) { |
|||
tf := &TaggedFields{} |
|||
encoded := tf.Encode() |
|||
expected := []byte{0} |
|||
|
|||
if !bytes.Equal(encoded, expected) { |
|||
t.Errorf("Empty TaggedFields.Encode() = %v, want %v", encoded, expected) |
|||
} |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeTaggedFields(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeTaggedFields failed: %v", err) |
|||
} |
|||
|
|||
if len(decoded.Fields) != 0 { |
|||
t.Errorf("Decoded tagged fields length = %d, want 0", len(decoded.Fields)) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
|
|||
t.Run("Single tagged field", func(t *testing.T) { |
|||
tf := &TaggedFields{ |
|||
Fields: []TaggedField{ |
|||
{Tag: 1, Data: []byte("test")}, |
|||
}, |
|||
} |
|||
|
|||
encoded := tf.Encode() |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeTaggedFields(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeTaggedFields failed: %v", err) |
|||
} |
|||
|
|||
if len(decoded.Fields) != 1 { |
|||
t.Fatalf("Decoded tagged fields length = %d, want 1", len(decoded.Fields)) |
|||
} |
|||
|
|||
field := decoded.Fields[0] |
|||
if field.Tag != 1 { |
|||
t.Errorf("Decoded tag = %d, want 1", field.Tag) |
|||
} |
|||
|
|||
if !bytes.Equal(field.Data, []byte("test")) { |
|||
t.Errorf("Decoded data = %v, want %v", field.Data, []byte("test")) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
|
|||
t.Run("Multiple tagged fields", func(t *testing.T) { |
|||
tf := &TaggedFields{ |
|||
Fields: []TaggedField{ |
|||
{Tag: 1, Data: []byte("first")}, |
|||
{Tag: 5, Data: []byte("second")}, |
|||
}, |
|||
} |
|||
|
|||
encoded := tf.Encode() |
|||
|
|||
// Test round trip
|
|||
decoded, consumed, err := DecodeTaggedFields(encoded) |
|||
if err != nil { |
|||
t.Fatalf("DecodeTaggedFields failed: %v", err) |
|||
} |
|||
|
|||
if len(decoded.Fields) != 2 { |
|||
t.Fatalf("Decoded tagged fields length = %d, want 2", len(decoded.Fields)) |
|||
} |
|||
|
|||
// Check first field
|
|||
field1 := decoded.Fields[0] |
|||
if field1.Tag != 1 { |
|||
t.Errorf("Decoded field 1 tag = %d, want 1", field1.Tag) |
|||
} |
|||
if !bytes.Equal(field1.Data, []byte("first")) { |
|||
t.Errorf("Decoded field 1 data = %v, want %v", field1.Data, []byte("first")) |
|||
} |
|||
|
|||
// Check second field
|
|||
field2 := decoded.Fields[1] |
|||
if field2.Tag != 5 { |
|||
t.Errorf("Decoded field 2 tag = %d, want 5", field2.Tag) |
|||
} |
|||
if !bytes.Equal(field2.Data, []byte("second")) { |
|||
t.Errorf("Decoded field 2 data = %v, want %v", field2.Data, []byte("second")) |
|||
} |
|||
|
|||
if consumed != len(encoded) { |
|||
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded)) |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func TestIsFlexibleVersion(t *testing.T) { |
|||
testCases := []struct { |
|||
apiKey uint16 |
|||
apiVersion uint16 |
|||
expected bool |
|||
name string |
|||
}{ |
|||
// ApiVersions
|
|||
{18, 2, false, "ApiVersions v2"}, |
|||
{18, 3, true, "ApiVersions v3"}, |
|||
{18, 4, true, "ApiVersions v4"}, |
|||
|
|||
// Metadata
|
|||
{3, 8, false, "Metadata v8"}, |
|||
{3, 9, true, "Metadata v9"}, |
|||
{3, 10, true, "Metadata v10"}, |
|||
|
|||
// Fetch
|
|||
{1, 11, false, "Fetch v11"}, |
|||
{1, 12, true, "Fetch v12"}, |
|||
{1, 13, true, "Fetch v13"}, |
|||
|
|||
// Produce
|
|||
{0, 8, false, "Produce v8"}, |
|||
{0, 9, true, "Produce v9"}, |
|||
{0, 10, true, "Produce v10"}, |
|||
|
|||
// CreateTopics
|
|||
{19, 1, false, "CreateTopics v1"}, |
|||
{19, 2, true, "CreateTopics v2"}, |
|||
{19, 3, true, "CreateTopics v3"}, |
|||
|
|||
// Unknown API
|
|||
{99, 1, false, "Unknown API"}, |
|||
} |
|||
|
|||
for _, tc := range testCases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
result := IsFlexibleVersion(tc.apiKey, tc.apiVersion) |
|||
if result != tc.expected { |
|||
t.Errorf("IsFlexibleVersion(%d, %d) = %v, want %v", |
|||
tc.apiKey, tc.apiVersion, result, tc.expected) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestParseRequestHeader(t *testing.T) { |
|||
t.Run("Regular version header", func(t *testing.T) { |
|||
// Construct a regular version header (Metadata v1)
|
|||
data := make([]byte, 0) |
|||
data = append(data, 0, 3) // API Key = 3 (Metadata)
|
|||
data = append(data, 0, 1) // API Version = 1
|
|||
data = append(data, 0, 0, 0, 123) // Correlation ID = 123
|
|||
data = append(data, 0, 4) // Client ID length = 4
|
|||
data = append(data, 't', 'e', 's', 't') // Client ID = "test"
|
|||
data = append(data, 1, 2, 3) // Request body
|
|||
|
|||
header, body, err := ParseRequestHeader(data) |
|||
if err != nil { |
|||
t.Fatalf("ParseRequestHeader failed: %v", err) |
|||
} |
|||
|
|||
if header.APIKey != 3 { |
|||
t.Errorf("APIKey = %d, want 3", header.APIKey) |
|||
} |
|||
if header.APIVersion != 1 { |
|||
t.Errorf("APIVersion = %d, want 1", header.APIVersion) |
|||
} |
|||
if header.CorrelationID != 123 { |
|||
t.Errorf("CorrelationID = %d, want 123", header.CorrelationID) |
|||
} |
|||
if header.ClientID == nil || *header.ClientID != "test" { |
|||
t.Errorf("ClientID = %v, want 'test'", header.ClientID) |
|||
} |
|||
if header.TaggedFields != nil { |
|||
t.Errorf("TaggedFields should be nil for regular versions") |
|||
} |
|||
|
|||
expectedBody := []byte{1, 2, 3} |
|||
if !bytes.Equal(body, expectedBody) { |
|||
t.Errorf("Body = %v, want %v", body, expectedBody) |
|||
} |
|||
}) |
|||
|
|||
t.Run("Flexible version header", func(t *testing.T) { |
|||
// Construct a flexible version header (ApiVersions v3)
|
|||
data := make([]byte, 0) |
|||
data = append(data, 0, 18) // API Key = 18 (ApiVersions)
|
|||
data = append(data, 0, 3) // API Version = 3 (flexible)
|
|||
|
|||
// Correlation ID = 456 (4 bytes, big endian)
|
|||
correlationID := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(correlationID, 456) |
|||
data = append(data, correlationID...) |
|||
|
|||
data = append(data, 5, 't', 'e', 's', 't') // Client ID = "test" (compact string)
|
|||
data = append(data, 0) // Empty tagged fields
|
|||
data = append(data, 4, 5, 6) // Request body
|
|||
|
|||
header, body, err := ParseRequestHeader(data) |
|||
if err != nil { |
|||
t.Fatalf("ParseRequestHeader failed: %v", err) |
|||
} |
|||
|
|||
if header.APIKey != 18 { |
|||
t.Errorf("APIKey = %d, want 18", header.APIKey) |
|||
} |
|||
if header.APIVersion != 3 { |
|||
t.Errorf("APIVersion = %d, want 3", header.APIVersion) |
|||
} |
|||
if header.CorrelationID != 456 { |
|||
t.Errorf("CorrelationID = %d, want 456", header.CorrelationID) |
|||
} |
|||
if header.ClientID == nil || *header.ClientID != "test" { |
|||
t.Errorf("ClientID = %v, want 'test'", header.ClientID) |
|||
} |
|||
if header.TaggedFields == nil { |
|||
t.Errorf("TaggedFields should not be nil for flexible versions") |
|||
} |
|||
if len(header.TaggedFields.Fields) != 0 { |
|||
t.Errorf("TaggedFields should be empty") |
|||
} |
|||
|
|||
expectedBody := []byte{4, 5, 6} |
|||
if !bytes.Equal(body, expectedBody) { |
|||
t.Errorf("Body = %v, want %v", body, expectedBody) |
|||
} |
|||
}) |
|||
|
|||
t.Run("Null client ID", func(t *testing.T) { |
|||
// Regular version with null client ID
|
|||
data := make([]byte, 0) |
|||
data = append(data, 0, 3) // API Key = 3 (Metadata)
|
|||
data = append(data, 0, 1) // API Version = 1
|
|||
|
|||
// Correlation ID = 789 (4 bytes, big endian)
|
|||
correlationID := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(correlationID, 789) |
|||
data = append(data, correlationID...) |
|||
|
|||
data = append(data, 0xFF, 0xFF) // Client ID length = -1 (null)
|
|||
data = append(data, 7, 8, 9) // Request body
|
|||
|
|||
header, body, err := ParseRequestHeader(data) |
|||
if err != nil { |
|||
t.Fatalf("ParseRequestHeader failed: %v", err) |
|||
} |
|||
|
|||
if header.ClientID != nil { |
|||
t.Errorf("ClientID = %v, want nil", header.ClientID) |
|||
} |
|||
|
|||
expectedBody := []byte{7, 8, 9} |
|||
if !bytes.Equal(body, expectedBody) { |
|||
t.Errorf("Body = %v, want %v", body, expectedBody) |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func TestEncodeFlexibleResponse(t *testing.T) { |
|||
correlationID := uint32(123) |
|||
data := []byte{1, 2, 3, 4} |
|||
|
|||
t.Run("Without tagged fields", func(t *testing.T) { |
|||
result := EncodeFlexibleResponse(correlationID, data, false) |
|||
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4} |
|||
|
|||
if !bytes.Equal(result, expected) { |
|||
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected) |
|||
} |
|||
}) |
|||
|
|||
t.Run("With tagged fields", func(t *testing.T) { |
|||
result := EncodeFlexibleResponse(correlationID, data, true) |
|||
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4, 0} // 0 at end for empty tagged fields
|
|||
|
|||
if !bytes.Equal(result, expected) { |
|||
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected) |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func BenchmarkEncodeUvarint(b *testing.B) { |
|||
testValues := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF} |
|||
|
|||
b.ResetTimer() |
|||
for i := 0; i < b.N; i++ { |
|||
for _, val := range testValues { |
|||
EncodeUvarint(val) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func BenchmarkDecodeUvarint(b *testing.B) { |
|||
// Pre-encode test values
|
|||
testData := [][]byte{ |
|||
EncodeUvarint(0), |
|||
EncodeUvarint(127), |
|||
EncodeUvarint(128), |
|||
EncodeUvarint(16383), |
|||
EncodeUvarint(16384), |
|||
EncodeUvarint(65535), |
|||
EncodeUvarint(65536), |
|||
EncodeUvarint(0xFFFFFFFF), |
|||
} |
|||
|
|||
b.ResetTimer() |
|||
for i := 0; i < b.N; i++ { |
|||
for _, data := range testData { |
|||
DecodeUvarint(data) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func BenchmarkFlexibleString(b *testing.B) { |
|||
testStrings := []string{"", "a", "hello", "this is a longer test string", "测试中文字符串"} |
|||
|
|||
b.ResetTimer() |
|||
for i := 0; i < b.N; i++ { |
|||
for _, s := range testStrings { |
|||
FlexibleString(s) |
|||
} |
|||
} |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue