You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
414 lines
12 KiB
414 lines
12 KiB
package protocol
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestKafkaErrorCodes(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
errorCode int16
|
|
expectedInfo ErrorInfo
|
|
}{
|
|
{
|
|
name: "No error",
|
|
errorCode: ErrorCodeNone,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 0, Name: "NONE", Description: "No error", Retriable: false,
|
|
},
|
|
},
|
|
{
|
|
name: "Unknown server error",
|
|
errorCode: ErrorCodeUnknownServerError,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 1, Name: "UNKNOWN_SERVER_ERROR", Description: "Unknown server error", Retriable: true,
|
|
},
|
|
},
|
|
{
|
|
name: "Topic already exists",
|
|
errorCode: ErrorCodeTopicAlreadyExists,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 36, Name: "TOPIC_ALREADY_EXISTS", Description: "Topic already exists", Retriable: false,
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid partitions",
|
|
errorCode: ErrorCodeInvalidPartitions,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 37, Name: "INVALID_PARTITIONS", Description: "Invalid number of partitions", Retriable: false,
|
|
},
|
|
},
|
|
{
|
|
name: "Request timed out",
|
|
errorCode: ErrorCodeRequestTimedOut,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 7, Name: "REQUEST_TIMED_OUT", Description: "Request timed out", Retriable: true,
|
|
},
|
|
},
|
|
{
|
|
name: "Connection timeout",
|
|
errorCode: ErrorCodeConnectionTimeout,
|
|
expectedInfo: ErrorInfo{
|
|
Code: 61, Name: "CONNECTION_TIMEOUT", Description: "Connection timeout", Retriable: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
info := GetErrorInfo(tt.errorCode)
|
|
if info.Code != tt.expectedInfo.Code {
|
|
t.Errorf("GetErrorInfo().Code = %d, want %d", info.Code, tt.expectedInfo.Code)
|
|
}
|
|
if info.Name != tt.expectedInfo.Name {
|
|
t.Errorf("GetErrorInfo().Name = %s, want %s", info.Name, tt.expectedInfo.Name)
|
|
}
|
|
if info.Description != tt.expectedInfo.Description {
|
|
t.Errorf("GetErrorInfo().Description = %s, want %s", info.Description, tt.expectedInfo.Description)
|
|
}
|
|
if info.Retriable != tt.expectedInfo.Retriable {
|
|
t.Errorf("GetErrorInfo().Retriable = %v, want %v", info.Retriable, tt.expectedInfo.Retriable)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsRetriableError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
errorCode int16
|
|
retriable bool
|
|
}{
|
|
{"None", ErrorCodeNone, false},
|
|
{"Unknown server error", ErrorCodeUnknownServerError, true},
|
|
{"Topic already exists", ErrorCodeTopicAlreadyExists, false},
|
|
{"Request timed out", ErrorCodeRequestTimedOut, true},
|
|
{"Rebalance in progress", ErrorCodeRebalanceInProgress, true},
|
|
{"Invalid group ID", ErrorCodeInvalidGroupID, false},
|
|
{"Network exception", ErrorCodeNetworkException, true},
|
|
{"Connection timeout", ErrorCodeConnectionTimeout, true},
|
|
{"Read timeout", ErrorCodeReadTimeout, true},
|
|
{"Write timeout", ErrorCodeWriteTimeout, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := IsRetriableError(tt.errorCode); got != tt.retriable {
|
|
t.Errorf("IsRetriableError() = %v, want %v", got, tt.retriable)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildErrorResponse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
correlationID uint32
|
|
errorCode int16
|
|
expectedLen int
|
|
}{
|
|
{"Basic error response", 12345, ErrorCodeUnknownServerError, 6},
|
|
{"Topic already exists", 67890, ErrorCodeTopicAlreadyExists, 6},
|
|
{"No error", 11111, ErrorCodeNone, 6},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
response := BuildErrorResponse(tt.correlationID, tt.errorCode)
|
|
|
|
if len(response) != tt.expectedLen {
|
|
t.Errorf("BuildErrorResponse() length = %d, want %d", len(response), tt.expectedLen)
|
|
}
|
|
|
|
// Verify correlation ID
|
|
if len(response) >= 4 {
|
|
correlationID := binary.BigEndian.Uint32(response[0:4])
|
|
if correlationID != tt.correlationID {
|
|
t.Errorf("Correlation ID = %d, want %d", correlationID, tt.correlationID)
|
|
}
|
|
}
|
|
|
|
// Verify error code
|
|
if len(response) >= 6 {
|
|
errorCode := binary.BigEndian.Uint16(response[4:6])
|
|
if errorCode != uint16(tt.errorCode) {
|
|
t.Errorf("Error code = %d, want %d", errorCode, uint16(tt.errorCode))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildErrorResponseWithMessage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
correlationID uint32
|
|
errorCode int16
|
|
message string
|
|
expectNullMsg bool
|
|
}{
|
|
{"Error with message", 12345, ErrorCodeUnknownServerError, "Test error message", false},
|
|
{"Error with empty message", 67890, ErrorCodeTopicAlreadyExists, "", true},
|
|
{"Error with long message", 11111, ErrorCodeInvalidPartitions, "This is a longer error message for testing", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
response := BuildErrorResponseWithMessage(tt.correlationID, tt.errorCode, tt.message)
|
|
|
|
// Should have at least correlation ID (4) + error code (2) + message length (2)
|
|
minLen := 8
|
|
if len(response) < minLen {
|
|
t.Errorf("BuildErrorResponseWithMessage() length = %d, want at least %d", len(response), minLen)
|
|
}
|
|
|
|
// Verify correlation ID
|
|
correlationID := binary.BigEndian.Uint32(response[0:4])
|
|
if correlationID != tt.correlationID {
|
|
t.Errorf("Correlation ID = %d, want %d", correlationID, tt.correlationID)
|
|
}
|
|
|
|
// Verify error code
|
|
errorCode := binary.BigEndian.Uint16(response[4:6])
|
|
if errorCode != uint16(tt.errorCode) {
|
|
t.Errorf("Error code = %d, want %d", errorCode, uint16(tt.errorCode))
|
|
}
|
|
|
|
// Verify message
|
|
if tt.expectNullMsg {
|
|
// Should have null string marker (0xFF, 0xFF)
|
|
if len(response) >= 8 && (response[6] != 0xFF || response[7] != 0xFF) {
|
|
t.Errorf("Expected null string marker, got %x %x", response[6], response[7])
|
|
}
|
|
} else {
|
|
// Should have message length and message
|
|
if len(response) >= 8 {
|
|
messageLen := binary.BigEndian.Uint16(response[6:8])
|
|
if messageLen != uint16(len(tt.message)) {
|
|
t.Errorf("Message length = %d, want %d", messageLen, len(tt.message))
|
|
}
|
|
|
|
if len(response) >= 8+len(tt.message) {
|
|
actualMessage := string(response[8 : 8+len(tt.message)])
|
|
if actualMessage != tt.message {
|
|
t.Errorf("Message = %q, want %q", actualMessage, tt.message)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClassifyNetworkError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected int16
|
|
}{
|
|
{"No error", nil, ErrorCodeNone},
|
|
{"Generic error", errors.New("generic error"), ErrorCodeUnknownServerError},
|
|
{"Connection refused", errors.New("connection refused"), ErrorCodeConnectionRefused},
|
|
{"Connection timeout", errors.New("connection timeout"), ErrorCodeConnectionTimeout},
|
|
{"Network timeout", &timeoutError{}, ErrorCodeRequestTimedOut},
|
|
{"Network error", &networkError{}, ErrorCodeNetworkException},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := ClassifyNetworkError(tt.err); got != tt.expected {
|
|
t.Errorf("ClassifyNetworkError() = %v, want %v", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleTimeoutError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
operation string
|
|
expected int16
|
|
}{
|
|
{"No error", nil, "read", ErrorCodeNone},
|
|
{"Read timeout", &timeoutError{}, "read", ErrorCodeReadTimeout},
|
|
{"Write timeout", &timeoutError{}, "write", ErrorCodeWriteTimeout},
|
|
{"Connect timeout", &timeoutError{}, "connect", ErrorCodeConnectionTimeout},
|
|
{"Generic timeout", &timeoutError{}, "unknown", ErrorCodeRequestTimedOut},
|
|
{"Non-timeout error", errors.New("other error"), "read", ErrorCodeUnknownServerError},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := HandleTimeoutError(tt.err, tt.operation); got != tt.expected {
|
|
t.Errorf("HandleTimeoutError() = %v, want %v", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultTimeoutConfig(t *testing.T) {
|
|
config := DefaultTimeoutConfig()
|
|
|
|
if config.ConnectionTimeout != 30*time.Second {
|
|
t.Errorf("ConnectionTimeout = %v, want %v", config.ConnectionTimeout, 30*time.Second)
|
|
}
|
|
if config.ReadTimeout != 10*time.Second {
|
|
t.Errorf("ReadTimeout = %v, want %v", config.ReadTimeout, 10*time.Second)
|
|
}
|
|
if config.WriteTimeout != 10*time.Second {
|
|
t.Errorf("WriteTimeout = %v, want %v", config.WriteTimeout, 10*time.Second)
|
|
}
|
|
if config.RequestTimeout != 30*time.Second {
|
|
t.Errorf("RequestTimeout = %v, want %v", config.RequestTimeout, 30*time.Second)
|
|
}
|
|
}
|
|
|
|
func TestSafeFormatError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected string
|
|
}{
|
|
{"No error", nil, ""},
|
|
{"Generic error", errors.New("test error"), "Error: test error"},
|
|
{"Network error", &networkError{}, "Error: network error"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := SafeFormatError(tt.err); got != tt.expected {
|
|
t.Errorf("SafeFormatError() = %q, want %q", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetErrorInfo_UnknownErrorCode(t *testing.T) {
|
|
unknownCode := int16(9999)
|
|
info := GetErrorInfo(unknownCode)
|
|
|
|
if info.Code != unknownCode {
|
|
t.Errorf("Code = %d, want %d", info.Code, unknownCode)
|
|
}
|
|
if info.Name != "UNKNOWN" {
|
|
t.Errorf("Name = %s, want UNKNOWN", info.Name)
|
|
}
|
|
if info.Description != "Unknown error code" {
|
|
t.Errorf("Description = %s, want 'Unknown error code'", info.Description)
|
|
}
|
|
if info.Retriable != false {
|
|
t.Errorf("Retriable = %v, want false", info.Retriable)
|
|
}
|
|
}
|
|
|
|
// Integration test for error handling in protocol context
|
|
func TestErrorHandling_Integration(t *testing.T) {
|
|
// Test building various protocol error responses
|
|
tests := []struct {
|
|
name string
|
|
apiKey uint16
|
|
errorCode int16
|
|
message string
|
|
}{
|
|
{"ApiVersions error", 18, ErrorCodeUnsupportedVersion, "Version not supported"},
|
|
{"Metadata error", 3, ErrorCodeUnknownTopicOrPartition, "Topic not found"},
|
|
{"Produce error", 0, ErrorCodeMessageTooLarge, "Message exceeds size limit"},
|
|
{"Fetch error", 1, ErrorCodeOffsetOutOfRange, "Offset out of range"},
|
|
{"CreateTopics error", 19, ErrorCodeTopicAlreadyExists, "Topic already exists"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
correlationID := uint32(12345)
|
|
|
|
// Test basic error response
|
|
basicResponse := BuildErrorResponse(correlationID, tt.errorCode)
|
|
if len(basicResponse) != 6 {
|
|
t.Errorf("Basic response length = %d, want 6", len(basicResponse))
|
|
}
|
|
|
|
// Test error response with message
|
|
messageResponse := BuildErrorResponseWithMessage(correlationID, tt.errorCode, tt.message)
|
|
expectedMinLen := 8 + len(tt.message) // 4 (correlationID) + 2 (errorCode) + 2 (messageLen) + len(message)
|
|
if len(messageResponse) < expectedMinLen {
|
|
t.Errorf("Message response length = %d, want at least %d", len(messageResponse), expectedMinLen)
|
|
}
|
|
|
|
// Verify error is correctly classified
|
|
info := GetErrorInfo(tt.errorCode)
|
|
if info.Code != tt.errorCode {
|
|
t.Errorf("Error info code = %d, want %d", info.Code, tt.errorCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Mock error types for testing. Besides Error, both types expose Timeout and
// Temporary methods. NOTE(review): this presumably lets them satisfy net.Error
// for ClassifyNetworkError / HandleTimeoutError; confirm against those
// functions.
type (
	// timeoutError is an error whose Timeout and Temporary methods both
	// report true.
	timeoutError struct{}

	// networkError is an error that reports Temporary but not Timeout.
	networkError struct{}
)

func (*timeoutError) Error() string   { return "timeout error" }
func (*timeoutError) Timeout() bool   { return true }
func (*timeoutError) Temporary() bool { return true }

func (*networkError) Error() string   { return "network error" }
func (*networkError) Timeout() bool   { return false }
func (*networkError) Temporary() bool { return true }
|
|
|
|
// Test timeout detection
|
|
func TestTimeoutDetection(t *testing.T) {
|
|
// Test with context timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
|
defer cancel()
|
|
|
|
// Wait for context to timeout
|
|
time.Sleep(2 * time.Millisecond)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
err := ctx.Err()
|
|
errorCode := HandleTimeoutError(err, "context")
|
|
if errorCode != ErrorCodeRequestTimedOut {
|
|
t.Errorf("Context timeout error code = %v, want %v", errorCode, ErrorCodeRequestTimedOut)
|
|
}
|
|
default:
|
|
t.Error("Context should have timed out")
|
|
}
|
|
}
|
|
|
|
// Benchmark error response building
|
|
func BenchmarkBuildErrorResponse(b *testing.B) {
|
|
correlationID := uint32(12345)
|
|
errorCode := ErrorCodeUnknownServerError
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
BuildErrorResponse(correlationID, errorCode)
|
|
}
|
|
}
|
|
|
|
func BenchmarkBuildErrorResponseWithMessage(b *testing.B) {
|
|
correlationID := uint32(12345)
|
|
errorCode := ErrorCodeUnknownServerError
|
|
message := "This is a test error message"
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
BuildErrorResponseWithMessage(correlationID, errorCode, message)
|
|
}
|
|
}
|
|
|
|
func BenchmarkClassifyNetworkError(b *testing.B) {
|
|
err := &timeoutError{}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
ClassifyNetworkError(err)
|
|
}
|
|
}
|