You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
486 lines
13 KiB
486 lines
13 KiB
package protocol
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"testing"
|
|
)
|
|
|
|
func TestEncodeDecodeUvarint(t *testing.T) {
|
|
testCases := []uint32{
|
|
0, 1, 127, 128, 255, 256, 16383, 16384, 32767, 32768, 65535, 65536,
|
|
0x1FFFFF, 0x200000, 0x0FFFFFFF, 0x10000000, 0xFFFFFFFF,
|
|
}
|
|
|
|
for _, value := range testCases {
|
|
t.Run(fmt.Sprintf("value_%d", value), func(t *testing.T) {
|
|
encoded := EncodeUvarint(value)
|
|
decoded, consumed, err := DecodeUvarint(encoded)
|
|
|
|
if err != nil {
|
|
t.Fatalf("DecodeUvarint failed: %v", err)
|
|
}
|
|
|
|
if decoded != value {
|
|
t.Errorf("Decoded value %d != original %d", decoded, value)
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCompactArrayLength(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
length uint32
|
|
expected []byte
|
|
}{
|
|
{"Empty array", 0, []byte{0}},
|
|
{"Single element", 1, []byte{2}},
|
|
{"Small array", 10, []byte{11}},
|
|
{"Large array", 127, []byte{128, 1}}, // 128 = 127+1 encoded as varint (two bytes since >= 128)
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
encoded := CompactArrayLength(tc.length)
|
|
if !bytes.Equal(encoded, tc.expected) {
|
|
t.Errorf("CompactArrayLength(%d) = %v, want %v", tc.length, encoded, tc.expected)
|
|
}
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeCompactArrayLength(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeCompactArrayLength failed: %v", err)
|
|
}
|
|
|
|
if decoded != tc.length {
|
|
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length)
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCompactStringLength(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
length int
|
|
expected []byte
|
|
}{
|
|
{"Null string", -1, []byte{0}},
|
|
{"Empty string", 0, []byte{1}},
|
|
{"Short string", 5, []byte{6}},
|
|
{"Medium string", 100, []byte{101}}, // 101 encoded as varint (single byte since < 128)
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
encoded := CompactStringLength(tc.length)
|
|
if !bytes.Equal(encoded, tc.expected) {
|
|
t.Errorf("CompactStringLength(%d) = %v, want %v", tc.length, encoded, tc.expected)
|
|
}
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeCompactStringLength(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeCompactStringLength failed: %v", err)
|
|
}
|
|
|
|
if decoded != tc.length {
|
|
t.Errorf("Round trip failed: got %d, want %d", decoded, tc.length)
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFlexibleString(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
input string
|
|
expected []byte
|
|
}{
|
|
{"Empty string", "", []byte{0}},
|
|
{"Hello", "hello", []byte{6, 'h', 'e', 'l', 'l', 'o'}},
|
|
{"Unicode", "测试", []byte{7, 0xE6, 0xB5, 0x8B, 0xE8, 0xAF, 0x95}}, // UTF-8 encoded
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
encoded := FlexibleString(tc.input)
|
|
if !bytes.Equal(encoded, tc.expected) {
|
|
t.Errorf("FlexibleString(%q) = %v, want %v", tc.input, encoded, tc.expected)
|
|
}
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeFlexibleString(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeFlexibleString failed: %v", err)
|
|
}
|
|
|
|
if decoded != tc.input {
|
|
t.Errorf("Round trip failed: got %q, want %q", decoded, tc.input)
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFlexibleNullableString(t *testing.T) {
|
|
// Null string
|
|
nullResult := FlexibleNullableString(nil)
|
|
expected := []byte{0}
|
|
if !bytes.Equal(nullResult, expected) {
|
|
t.Errorf("FlexibleNullableString(nil) = %v, want %v", nullResult, expected)
|
|
}
|
|
|
|
// Non-null string
|
|
str := "test"
|
|
nonNullResult := FlexibleNullableString(&str)
|
|
expectedNonNull := []byte{5, 't', 'e', 's', 't'}
|
|
if !bytes.Equal(nonNullResult, expectedNonNull) {
|
|
t.Errorf("FlexibleNullableString(&%q) = %v, want %v", str, nonNullResult, expectedNonNull)
|
|
}
|
|
}
|
|
|
|
func TestTaggedFields(t *testing.T) {
|
|
t.Run("Empty tagged fields", func(t *testing.T) {
|
|
tf := &TaggedFields{}
|
|
encoded := tf.Encode()
|
|
expected := []byte{0}
|
|
|
|
if !bytes.Equal(encoded, expected) {
|
|
t.Errorf("Empty TaggedFields.Encode() = %v, want %v", encoded, expected)
|
|
}
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeTaggedFields(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeTaggedFields failed: %v", err)
|
|
}
|
|
|
|
if len(decoded.Fields) != 0 {
|
|
t.Errorf("Decoded tagged fields length = %d, want 0", len(decoded.Fields))
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
|
|
t.Run("Single tagged field", func(t *testing.T) {
|
|
tf := &TaggedFields{
|
|
Fields: []TaggedField{
|
|
{Tag: 1, Data: []byte("test")},
|
|
},
|
|
}
|
|
|
|
encoded := tf.Encode()
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeTaggedFields(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeTaggedFields failed: %v", err)
|
|
}
|
|
|
|
if len(decoded.Fields) != 1 {
|
|
t.Fatalf("Decoded tagged fields length = %d, want 1", len(decoded.Fields))
|
|
}
|
|
|
|
field := decoded.Fields[0]
|
|
if field.Tag != 1 {
|
|
t.Errorf("Decoded tag = %d, want 1", field.Tag)
|
|
}
|
|
|
|
if !bytes.Equal(field.Data, []byte("test")) {
|
|
t.Errorf("Decoded data = %v, want %v", field.Data, []byte("test"))
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple tagged fields", func(t *testing.T) {
|
|
tf := &TaggedFields{
|
|
Fields: []TaggedField{
|
|
{Tag: 1, Data: []byte("first")},
|
|
{Tag: 5, Data: []byte("second")},
|
|
},
|
|
}
|
|
|
|
encoded := tf.Encode()
|
|
|
|
// Test round trip
|
|
decoded, consumed, err := DecodeTaggedFields(encoded)
|
|
if err != nil {
|
|
t.Fatalf("DecodeTaggedFields failed: %v", err)
|
|
}
|
|
|
|
if len(decoded.Fields) != 2 {
|
|
t.Fatalf("Decoded tagged fields length = %d, want 2", len(decoded.Fields))
|
|
}
|
|
|
|
// Check first field
|
|
field1 := decoded.Fields[0]
|
|
if field1.Tag != 1 {
|
|
t.Errorf("Decoded field 1 tag = %d, want 1", field1.Tag)
|
|
}
|
|
if !bytes.Equal(field1.Data, []byte("first")) {
|
|
t.Errorf("Decoded field 1 data = %v, want %v", field1.Data, []byte("first"))
|
|
}
|
|
|
|
// Check second field
|
|
field2 := decoded.Fields[1]
|
|
if field2.Tag != 5 {
|
|
t.Errorf("Decoded field 2 tag = %d, want 5", field2.Tag)
|
|
}
|
|
if !bytes.Equal(field2.Data, []byte("second")) {
|
|
t.Errorf("Decoded field 2 data = %v, want %v", field2.Data, []byte("second"))
|
|
}
|
|
|
|
if consumed != len(encoded) {
|
|
t.Errorf("Consumed %d bytes but encoded %d bytes", consumed, len(encoded))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIsFlexibleVersion(t *testing.T) {
|
|
testCases := []struct {
|
|
apiKey uint16
|
|
apiVersion uint16
|
|
expected bool
|
|
name string
|
|
}{
|
|
// ApiVersions
|
|
{18, 2, false, "ApiVersions v2"},
|
|
{18, 3, true, "ApiVersions v3"},
|
|
{18, 4, true, "ApiVersions v4"},
|
|
|
|
// Metadata
|
|
{3, 8, false, "Metadata v8"},
|
|
{3, 9, true, "Metadata v9"},
|
|
{3, 10, true, "Metadata v10"},
|
|
|
|
// Fetch
|
|
{1, 11, false, "Fetch v11"},
|
|
{1, 12, true, "Fetch v12"},
|
|
{1, 13, true, "Fetch v13"},
|
|
|
|
// Produce
|
|
{0, 8, false, "Produce v8"},
|
|
{0, 9, true, "Produce v9"},
|
|
{0, 10, true, "Produce v10"},
|
|
|
|
// CreateTopics
|
|
{19, 1, false, "CreateTopics v1"},
|
|
{19, 2, true, "CreateTopics v2"},
|
|
{19, 3, true, "CreateTopics v3"},
|
|
|
|
// Unknown API
|
|
{99, 1, false, "Unknown API"},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result := IsFlexibleVersion(tc.apiKey, tc.apiVersion)
|
|
if result != tc.expected {
|
|
t.Errorf("IsFlexibleVersion(%d, %d) = %v, want %v",
|
|
tc.apiKey, tc.apiVersion, result, tc.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseRequestHeader(t *testing.T) {
|
|
t.Run("Regular version header", func(t *testing.T) {
|
|
// Construct a regular version header (Metadata v1)
|
|
data := make([]byte, 0)
|
|
data = append(data, 0, 3) // API Key = 3 (Metadata)
|
|
data = append(data, 0, 1) // API Version = 1
|
|
data = append(data, 0, 0, 0, 123) // Correlation ID = 123
|
|
data = append(data, 0, 4) // Client ID length = 4
|
|
data = append(data, 't', 'e', 's', 't') // Client ID = "test"
|
|
data = append(data, 1, 2, 3) // Request body
|
|
|
|
header, body, err := ParseRequestHeader(data)
|
|
if err != nil {
|
|
t.Fatalf("ParseRequestHeader failed: %v", err)
|
|
}
|
|
|
|
if header.APIKey != 3 {
|
|
t.Errorf("APIKey = %d, want 3", header.APIKey)
|
|
}
|
|
if header.APIVersion != 1 {
|
|
t.Errorf("APIVersion = %d, want 1", header.APIVersion)
|
|
}
|
|
if header.CorrelationID != 123 {
|
|
t.Errorf("CorrelationID = %d, want 123", header.CorrelationID)
|
|
}
|
|
if header.ClientID == nil || *header.ClientID != "test" {
|
|
t.Errorf("ClientID = %v, want 'test'", header.ClientID)
|
|
}
|
|
if header.TaggedFields != nil {
|
|
t.Errorf("TaggedFields should be nil for regular versions")
|
|
}
|
|
|
|
expectedBody := []byte{1, 2, 3}
|
|
if !bytes.Equal(body, expectedBody) {
|
|
t.Errorf("Body = %v, want %v", body, expectedBody)
|
|
}
|
|
})
|
|
|
|
t.Run("Flexible version header", func(t *testing.T) {
|
|
// Construct a flexible version header (ApiVersions v3)
|
|
data := make([]byte, 0)
|
|
data = append(data, 0, 18) // API Key = 18 (ApiVersions)
|
|
data = append(data, 0, 3) // API Version = 3 (flexible)
|
|
|
|
// Correlation ID = 456 (4 bytes, big endian)
|
|
correlationID := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(correlationID, 456)
|
|
data = append(data, correlationID...)
|
|
|
|
data = append(data, 5, 't', 'e', 's', 't') // Client ID = "test" (compact string)
|
|
data = append(data, 0) // Empty tagged fields
|
|
data = append(data, 4, 5, 6) // Request body
|
|
|
|
header, body, err := ParseRequestHeader(data)
|
|
if err != nil {
|
|
t.Fatalf("ParseRequestHeader failed: %v", err)
|
|
}
|
|
|
|
if header.APIKey != 18 {
|
|
t.Errorf("APIKey = %d, want 18", header.APIKey)
|
|
}
|
|
if header.APIVersion != 3 {
|
|
t.Errorf("APIVersion = %d, want 3", header.APIVersion)
|
|
}
|
|
if header.CorrelationID != 456 {
|
|
t.Errorf("CorrelationID = %d, want 456", header.CorrelationID)
|
|
}
|
|
if header.ClientID == nil || *header.ClientID != "test" {
|
|
t.Errorf("ClientID = %v, want 'test'", header.ClientID)
|
|
}
|
|
if header.TaggedFields == nil {
|
|
t.Errorf("TaggedFields should not be nil for flexible versions")
|
|
}
|
|
if len(header.TaggedFields.Fields) != 0 {
|
|
t.Errorf("TaggedFields should be empty")
|
|
}
|
|
|
|
expectedBody := []byte{4, 5, 6}
|
|
if !bytes.Equal(body, expectedBody) {
|
|
t.Errorf("Body = %v, want %v", body, expectedBody)
|
|
}
|
|
})
|
|
|
|
t.Run("Null client ID", func(t *testing.T) {
|
|
// Regular version with null client ID
|
|
data := make([]byte, 0)
|
|
data = append(data, 0, 3) // API Key = 3 (Metadata)
|
|
data = append(data, 0, 1) // API Version = 1
|
|
|
|
// Correlation ID = 789 (4 bytes, big endian)
|
|
correlationID := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(correlationID, 789)
|
|
data = append(data, correlationID...)
|
|
|
|
data = append(data, 0xFF, 0xFF) // Client ID length = -1 (null)
|
|
data = append(data, 7, 8, 9) // Request body
|
|
|
|
header, body, err := ParseRequestHeader(data)
|
|
if err != nil {
|
|
t.Fatalf("ParseRequestHeader failed: %v", err)
|
|
}
|
|
|
|
if header.ClientID != nil {
|
|
t.Errorf("ClientID = %v, want nil", header.ClientID)
|
|
}
|
|
|
|
expectedBody := []byte{7, 8, 9}
|
|
if !bytes.Equal(body, expectedBody) {
|
|
t.Errorf("Body = %v, want %v", body, expectedBody)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestEncodeFlexibleResponse(t *testing.T) {
|
|
correlationID := uint32(123)
|
|
data := []byte{1, 2, 3, 4}
|
|
|
|
t.Run("Without tagged fields", func(t *testing.T) {
|
|
result := EncodeFlexibleResponse(correlationID, data, false)
|
|
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4}
|
|
|
|
if !bytes.Equal(result, expected) {
|
|
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected)
|
|
}
|
|
})
|
|
|
|
t.Run("With tagged fields", func(t *testing.T) {
|
|
result := EncodeFlexibleResponse(correlationID, data, true)
|
|
expected := []byte{0, 0, 0, 123, 1, 2, 3, 4, 0} // 0 at end for empty tagged fields
|
|
|
|
if !bytes.Equal(result, expected) {
|
|
t.Errorf("EncodeFlexibleResponse = %v, want %v", result, expected)
|
|
}
|
|
})
|
|
}
|
|
|
|
func BenchmarkEncodeUvarint(b *testing.B) {
|
|
testValues := []uint32{0, 127, 128, 16383, 16384, 65535, 65536, 0xFFFFFFFF}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
for _, val := range testValues {
|
|
EncodeUvarint(val)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkDecodeUvarint(b *testing.B) {
|
|
// Pre-encode test values
|
|
testData := [][]byte{
|
|
EncodeUvarint(0),
|
|
EncodeUvarint(127),
|
|
EncodeUvarint(128),
|
|
EncodeUvarint(16383),
|
|
EncodeUvarint(16384),
|
|
EncodeUvarint(65535),
|
|
EncodeUvarint(65536),
|
|
EncodeUvarint(0xFFFFFFFF),
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
for _, data := range testData {
|
|
DecodeUvarint(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkFlexibleString(b *testing.B) {
|
|
testStrings := []string{"", "a", "hello", "this is a longer test string", "测试中文字符串"}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
for _, s := range testStrings {
|
|
FlexibleString(s)
|
|
}
|
|
}
|
|
}
|