You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

486 lines
13 KiB

package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
)
// TestEncodeDecodeUvarint encodes each listed uint32 (from 0 up to
// 0xFFFFFFFF) with EncodeUvarint, decodes the result with DecodeUvarint, and
// verifies that the value survives the round trip and that the decoder reports
// consuming exactly as many bytes as were produced.
func TestEncodeDecodeUvarint(t *testing.T) {
	values := []uint32{
		0, 1, 127, 128, 255, 256, 16383, 16384, 32767, 32768, 65535, 65536,
		0x1FFFFF, 0x200000, 0x0FFFFFFF, 0x10000000, 0xFFFFFFFF,
	}

	// roundTrip builds the subtest body for a single input value.
	roundTrip := func(want uint32) func(*testing.T) {
		return func(t *testing.T) {
			wire := EncodeUvarint(want)

			got, n, err := DecodeUvarint(wire)
			if err != nil {
				t.Fatalf("DecodeUvarint failed: %v", err)
			}
			if got != want {
				t.Errorf("Decoded value %d != original %d", got, want)
			}
			if n != len(wire) {
				t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
			}
		}
	}

	for i := range values {
		v := values[i]
		t.Run(fmt.Sprintf("value_%d", v), roundTrip(v))
	}
}
// TestCompactArrayLength compares the bytes produced by CompactArrayLength
// with fixed expectations and then feeds them back through
// DecodeCompactArrayLength, checking both the recovered length and the number
// of bytes the decoder reports as consumed.
//
// NOTE(review): the "Empty array" row expects a single 0 byte, whereas every
// other row expects length+1. Kafka's compact arrays normally reserve 0 for a
// null array — confirm that this special case is the intended wire format.
func TestCompactArrayLength(t *testing.T) {
	type lengthCase struct {
		name     string
		length   uint32
		expected []byte
	}
	cases := []lengthCase{
		{name: "Empty array", length: 0, expected: []byte{0}},
		{name: "Single element", length: 1, expected: []byte{2}},
		{name: "Small array", length: 10, expected: []byte{11}},
		// 127+1 = 128 no longer fits in a single varint byte, so two bytes are expected.
		{name: "Large array", length: 127, expected: []byte{128, 1}},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			wire := CompactArrayLength(c.length)
			if !bytes.Equal(wire, c.expected) {
				t.Errorf("CompactArrayLength(%d) = %v, want %v", c.length, wire, c.expected)
			}

			// Decode what was just encoded and compare with the input.
			got, n, err := DecodeCompactArrayLength(wire)
			if err != nil {
				t.Fatalf("DecodeCompactArrayLength failed: %v", err)
			}
			if got != c.length {
				t.Errorf("Round trip failed: got %d, want %d", got, c.length)
			}
			if n != len(wire) {
				t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
			}
		})
	}
}
// TestCompactStringLength compares the bytes produced by CompactStringLength
// with fixed expectations (-1 stands for a null string and is expected to
// encode as 0; other lengths are expected to encode as length+1) and then
// checks that DecodeCompactStringLength recovers the original length and
// reports consuming every encoded byte.
func TestCompactStringLength(t *testing.T) {
	type lengthCase struct {
		name     string
		length   int
		expected []byte
	}
	cases := []lengthCase{
		{name: "Null string", length: -1, expected: []byte{0}},
		{name: "Empty string", length: 0, expected: []byte{1}},
		{name: "Short string", length: 5, expected: []byte{6}},
		// 100+1 = 101 is below 128, so a single varint byte is expected.
		{name: "Medium string", length: 100, expected: []byte{101}},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			wire := CompactStringLength(c.length)
			if !bytes.Equal(wire, c.expected) {
				t.Errorf("CompactStringLength(%d) = %v, want %v", c.length, wire, c.expected)
			}

			// Decode what was just encoded and compare with the input.
			got, n, err := DecodeCompactStringLength(wire)
			if err != nil {
				t.Fatalf("DecodeCompactStringLength failed: %v", err)
			}
			if got != c.length {
				t.Errorf("Round trip failed: got %d, want %d", got, c.length)
			}
			if n != len(wire) {
				t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
			}
		})
	}
}
// TestFlexibleString compares the output of FlexibleString with fixed byte
// sequences (a length prefix followed by the string's bytes) and then checks
// that DecodeFlexibleString returns the original string and reports consuming
// every encoded byte.
//
// NOTE(review): the "Empty string" row expects a single 0 byte, while
// TestCompactStringLength expects length 0 to encode as 1 — confirm that
// FlexibleString("") is meant to produce 0 here.
func TestFlexibleString(t *testing.T) {
	type stringCase struct {
		name     string
		input    string
		expected []byte
	}
	cases := []stringCase{
		{name: "Empty string", input: "", expected: []byte{0}},
		{name: "Hello", input: "hello", expected: append([]byte{6}, "hello"...)},
		// Two characters, three UTF-8 bytes each, so the prefix is 6+1 = 7.
		{name: "Unicode", input: "测试", expected: []byte{7, 0xE6, 0xB5, 0x8B, 0xE8, 0xAF, 0x95}},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			wire := FlexibleString(c.input)
			if !bytes.Equal(wire, c.expected) {
				t.Errorf("FlexibleString(%q) = %v, want %v", c.input, wire, c.expected)
			}

			// Decode what was just encoded and compare with the input.
			got, n, err := DecodeFlexibleString(wire)
			if err != nil {
				t.Fatalf("DecodeFlexibleString failed: %v", err)
			}
			if got != c.input {
				t.Errorf("Round trip failed: got %q, want %q", got, c.input)
			}
			if n != len(wire) {
				t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
			}
		})
	}
}
// TestFlexibleNullableString checks the two input shapes accepted by
// FlexibleNullableString: a nil pointer, expected to encode as a single 0
// byte, and a pointer to "test", expected to encode as 5 followed by the
// string's bytes.
func TestFlexibleNullableString(t *testing.T) {
	// nil pointer input
	wantNull := []byte{0}
	if got := FlexibleNullableString(nil); !bytes.Equal(got, wantNull) {
		t.Errorf("FlexibleNullableString(nil) = %v, want %v", got, wantNull)
	}

	// pointer to a real string
	value := "test"
	wantValue := []byte{5, 't', 'e', 's', 't'}
	if got := FlexibleNullableString(&value); !bytes.Equal(got, wantValue) {
		t.Errorf("FlexibleNullableString(&%q) = %v, want %v", value, got, wantValue)
	}
}
// TestTaggedFields exercises TaggedFields.Encode together with
// DecodeTaggedFields for zero, one and two fields. The empty case also pins
// the encoded bytes to a single 0; the other cases only verify that tags and
// data survive a round trip and that the decoder reports consuming every
// encoded byte.
func TestTaggedFields(t *testing.T) {
	t.Run("Empty tagged fields", func(t *testing.T) {
		wire := (&TaggedFields{}).Encode()

		want := []byte{0}
		if !bytes.Equal(wire, want) {
			t.Errorf("Empty TaggedFields.Encode() = %v, want %v", wire, want)
		}

		// Decode what was just encoded.
		got, n, err := DecodeTaggedFields(wire)
		if err != nil {
			t.Fatalf("DecodeTaggedFields failed: %v", err)
		}
		if count := len(got.Fields); count != 0 {
			t.Errorf("Decoded tagged fields length = %d, want 0", count)
		}
		if n != len(wire) {
			t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
		}
	})

	t.Run("Single tagged field", func(t *testing.T) {
		payload := []byte("test")
		input := &TaggedFields{
			Fields: []TaggedField{{Tag: 1, Data: []byte("test")}},
		}
		wire := input.Encode()

		// Decode what was just encoded.
		got, n, err := DecodeTaggedFields(wire)
		if err != nil {
			t.Fatalf("DecodeTaggedFields failed: %v", err)
		}
		// Stop here on a count mismatch so the index below cannot go out of range.
		if count := len(got.Fields); count != 1 {
			t.Fatalf("Decoded tagged fields length = %d, want 1", count)
		}

		only := got.Fields[0]
		if only.Tag != 1 {
			t.Errorf("Decoded tag = %d, want 1", only.Tag)
		}
		if !bytes.Equal(only.Data, payload) {
			t.Errorf("Decoded data = %v, want %v", only.Data, payload)
		}
		if n != len(wire) {
			t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
		}
	})

	t.Run("Multiple tagged fields", func(t *testing.T) {
		firstPayload := []byte("first")
		secondPayload := []byte("second")
		input := &TaggedFields{
			Fields: []TaggedField{
				{Tag: 1, Data: []byte("first")},
				{Tag: 5, Data: []byte("second")},
			},
		}
		wire := input.Encode()

		// Decode what was just encoded.
		got, n, err := DecodeTaggedFields(wire)
		if err != nil {
			t.Fatalf("DecodeTaggedFields failed: %v", err)
		}
		// Stop here on a count mismatch so the indexing below cannot go out of range.
		if count := len(got.Fields); count != 2 {
			t.Fatalf("Decoded tagged fields length = %d, want 2", count)
		}

		// First field: tag 1 carrying "first".
		first := got.Fields[0]
		if first.Tag != 1 {
			t.Errorf("Decoded field 1 tag = %d, want 1", first.Tag)
		}
		if !bytes.Equal(first.Data, firstPayload) {
			t.Errorf("Decoded field 1 data = %v, want %v", first.Data, firstPayload)
		}

		// Second field: tag 5 carrying "second".
		second := got.Fields[1]
		if second.Tag != 5 {
			t.Errorf("Decoded field 2 tag = %d, want 5", second.Tag)
		}
		if !bytes.Equal(second.Data, secondPayload) {
			t.Errorf("Decoded field 2 data = %v, want %v", second.Data, secondPayload)
		}

		if n != len(wire) {
			t.Errorf("Consumed %d bytes but encoded %d bytes", n, len(wire))
		}
	})
}
// TestIsFlexibleVersion runs IsFlexibleVersion over a table of
// (API key, API version) pairs. For each listed API the table holds one
// version expected to be non-flexible followed by two expected to be flexible,
// plus one unlisted API key (99) expected to be non-flexible.
func TestIsFlexibleVersion(t *testing.T) {
	type versionCase struct {
		key     uint16
		version uint16
		want    bool
		label   string
	}
	cases := []versionCase{
		// ApiVersions (key 18)
		{18, 2, false, "ApiVersions v2"},
		{18, 3, true, "ApiVersions v3"},
		{18, 4, true, "ApiVersions v4"},
		// Metadata (key 3)
		{3, 8, false, "Metadata v8"},
		{3, 9, true, "Metadata v9"},
		{3, 10, true, "Metadata v10"},
		// Fetch (key 1)
		{1, 11, false, "Fetch v11"},
		{1, 12, true, "Fetch v12"},
		{1, 13, true, "Fetch v13"},
		// Produce (key 0)
		{0, 8, false, "Produce v8"},
		{0, 9, true, "Produce v9"},
		{0, 10, true, "Produce v10"},
		// CreateTopics (key 19)
		{19, 1, false, "CreateTopics v1"},
		{19, 2, true, "CreateTopics v2"},
		{19, 3, true, "CreateTopics v3"},
		// A key that none of the rows above use
		{99, 1, false, "Unknown API"},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.label, func(t *testing.T) {
			if got := IsFlexibleVersion(c.key, c.version); got != c.want {
				t.Errorf("IsFlexibleVersion(%d, %d) = %v, want %v",
					c.key, c.version, got, c.want)
			}
		})
	}
}
// TestParseRequestHeader feeds hand-built request bytes to ParseRequestHeader
// and checks the parsed header fields and the returned body for three cases:
//
//   - a non-flexible header (Metadata v1) with a 2-byte length-prefixed client ID,
//   - a flexible header (ApiVersions v3) with a compact client ID followed by an
//     empty tagged-fields section,
//   - a non-flexible header whose client ID length is -1 (null).
//
// In every case the bytes appended after the header must come back unchanged
// as the body.
func TestParseRequestHeader(t *testing.T) {
	t.Run("Regular version header", func(t *testing.T) {
		// Construct a regular version header (Metadata v1)
		data := make([]byte, 0)
		data = append(data, 0, 3)               // API Key = 3 (Metadata)
		data = append(data, 0, 1)               // API Version = 1
		data = append(data, 0, 0, 0, 123)       // Correlation ID = 123
		data = append(data, 0, 4)               // Client ID length = 4
		data = append(data, 't', 'e', 's', 't') // Client ID = "test"
		data = append(data, 1, 2, 3)            // Request body

		header, body, err := ParseRequestHeader(data)
		if err != nil {
			t.Fatalf("ParseRequestHeader failed: %v", err)
		}
		if header.APIKey != 3 {
			t.Errorf("APIKey = %d, want 3", header.APIKey)
		}
		if header.APIVersion != 1 {
			t.Errorf("APIVersion = %d, want 1", header.APIVersion)
		}
		if header.CorrelationID != 123 {
			t.Errorf("CorrelationID = %d, want 123", header.CorrelationID)
		}
		// Report the pointed-to value rather than the pointer itself so a
		// failure shows what was actually parsed.
		switch {
		case header.ClientID == nil:
			t.Errorf("ClientID = nil, want 'test'")
		case *header.ClientID != "test":
			t.Errorf("ClientID = %q, want 'test'", *header.ClientID)
		}
		if header.TaggedFields != nil {
			t.Errorf("TaggedFields should be nil for regular versions")
		}
		expectedBody := []byte{1, 2, 3}
		if !bytes.Equal(body, expectedBody) {
			t.Errorf("Body = %v, want %v", body, expectedBody)
		}
	})

	t.Run("Flexible version header", func(t *testing.T) {
		// Construct a flexible version header (ApiVersions v3)
		data := make([]byte, 0)
		data = append(data, 0, 18) // API Key = 18 (ApiVersions)
		data = append(data, 0, 3)  // API Version = 3 (flexible)
		// Correlation ID = 456 (4 bytes, big endian)
		correlationID := make([]byte, 4)
		binary.BigEndian.PutUint32(correlationID, 456)
		data = append(data, correlationID...)
		data = append(data, 5, 't', 'e', 's', 't') // Client ID = "test" (compact string)
		data = append(data, 0)                     // Empty tagged fields
		data = append(data, 4, 5, 6)               // Request body

		header, body, err := ParseRequestHeader(data)
		if err != nil {
			t.Fatalf("ParseRequestHeader failed: %v", err)
		}
		if header.APIKey != 18 {
			t.Errorf("APIKey = %d, want 18", header.APIKey)
		}
		if header.APIVersion != 3 {
			t.Errorf("APIVersion = %d, want 3", header.APIVersion)
		}
		if header.CorrelationID != 456 {
			t.Errorf("CorrelationID = %d, want 456", header.CorrelationID)
		}
		// Report the pointed-to value rather than the pointer itself so a
		// failure shows what was actually parsed.
		switch {
		case header.ClientID == nil:
			t.Errorf("ClientID = nil, want 'test'")
		case *header.ClientID != "test":
			t.Errorf("ClientID = %q, want 'test'", *header.ClientID)
		}
		// Fatalf, not Errorf: the next check dereferences TaggedFields and
		// would panic on nil instead of reporting a clean test failure.
		if header.TaggedFields == nil {
			t.Fatalf("TaggedFields should not be nil for flexible versions")
		}
		if len(header.TaggedFields.Fields) != 0 {
			t.Errorf("TaggedFields should be empty")
		}
		expectedBody := []byte{4, 5, 6}
		if !bytes.Equal(body, expectedBody) {
			t.Errorf("Body = %v, want %v", body, expectedBody)
		}
	})

	t.Run("Null client ID", func(t *testing.T) {
		// Regular version with null client ID
		data := make([]byte, 0)
		data = append(data, 0, 3) // API Key = 3 (Metadata)
		data = append(data, 0, 1) // API Version = 1
		// Correlation ID = 789 (4 bytes, big endian)
		correlationID := make([]byte, 4)
		binary.BigEndian.PutUint32(correlationID, 789)
		data = append(data, correlationID...)
		data = append(data, 0xFF, 0xFF) // Client ID length = -1 (null)
		data = append(data, 7, 8, 9)    // Request body

		header, body, err := ParseRequestHeader(data)
		if err != nil {
			t.Fatalf("ParseRequestHeader failed: %v", err)
		}
		if header.ClientID != nil {
			// Non-nil here, so it is safe to show the unexpected value.
			t.Errorf("ClientID = %q, want nil", *header.ClientID)
		}
		expectedBody := []byte{7, 8, 9}
		if !bytes.Equal(body, expectedBody) {
			t.Errorf("Body = %v, want %v", body, expectedBody)
		}
	})
}
// TestEncodeFlexibleResponse calls EncodeFlexibleResponse with correlation ID
// 123 and a four-byte payload, once with the final flag false and once with
// it true. Both outputs are expected to start with the big-endian correlation
// ID followed by the payload; with the flag set, one extra trailing 0 byte
// (empty tagged fields) is expected.
func TestEncodeFlexibleResponse(t *testing.T) {
	var correlationID uint32 = 123
	payload := []byte{1, 2, 3, 4}

	scenarios := []struct {
		name       string
		withTagged bool
		want       []byte
	}{
		{"Without tagged fields", false, []byte{0, 0, 0, 123, 1, 2, 3, 4}},
		// The trailing 0 is the empty tagged-fields section.
		{"With tagged fields", true, []byte{0, 0, 0, 123, 1, 2, 3, 4, 0}},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			got := EncodeFlexibleResponse(correlationID, payload, sc.withTagged)
			if !bytes.Equal(got, sc.want) {
				t.Errorf("EncodeFlexibleResponse = %v, want %v", got, sc.want)
			}
		})
	}
}
// BenchmarkEncodeUvarint times EncodeUvarint over a fixed list of eight
// uint32 inputs; every benchmark iteration encodes the whole list once and
// discards the results.
func BenchmarkEncodeUvarint(b *testing.B) {
	inputs := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		for idx := range inputs {
			EncodeUvarint(inputs[idx])
		}
	}
}
// BenchmarkDecodeUvarint times DecodeUvarint over eight buffers that are
// encoded with EncodeUvarint before the timer is reset; every benchmark
// iteration decodes each buffer once and discards the results.
func BenchmarkDecodeUvarint(b *testing.B) {
	// Encoding happens up front so that only decoding is measured.
	values := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF}
	encoded := make([][]byte, 0, len(values))
	for _, v := range values {
		encoded = append(encoded, EncodeUvarint(v))
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		for idx := range encoded {
			DecodeUvarint(encoded[idx])
		}
	}
}
// BenchmarkFlexibleString times FlexibleString over five fixed strings (an
// empty string, ASCII strings of several lengths, and one non-ASCII string);
// every benchmark iteration encodes each string once and discards the
// results.
func BenchmarkFlexibleString(b *testing.B) {
	inputs := []string{
		"",
		"a",
		"hello",
		"this is a longer test string",
		"测试中文字符串",
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		for idx := range inputs {
			FlexibleString(inputs[idx])
		}
	}
}