You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

414 lines
12 KiB

package protocol
import (
"context"
"encoding/binary"
"errors"
"testing"
"time"
)
// TestKafkaErrorCodes checks that GetErrorInfo reports the expected code,
// name, description and retriable flag for a selection of error codes.
func TestKafkaErrorCodes(t *testing.T) {
	type testCase struct {
		name         string
		errorCode    int16
		expectedInfo ErrorInfo
	}

	cases := []testCase{
		{
			name:         "No error",
			errorCode:    ErrorCodeNone,
			expectedInfo: ErrorInfo{Code: 0, Name: "NONE", Description: "No error", Retriable: false},
		},
		{
			name:         "Unknown server error",
			errorCode:    ErrorCodeUnknownServerError,
			expectedInfo: ErrorInfo{Code: 1, Name: "UNKNOWN_SERVER_ERROR", Description: "Unknown server error", Retriable: true},
		},
		{
			name:         "Topic already exists",
			errorCode:    ErrorCodeTopicAlreadyExists,
			expectedInfo: ErrorInfo{Code: 36, Name: "TOPIC_ALREADY_EXISTS", Description: "Topic already exists", Retriable: false},
		},
		{
			name:         "Invalid partitions",
			errorCode:    ErrorCodeInvalidPartitions,
			expectedInfo: ErrorInfo{Code: 37, Name: "INVALID_PARTITIONS", Description: "Invalid number of partitions", Retriable: false},
		},
		{
			name:         "Request timed out",
			errorCode:    ErrorCodeRequestTimedOut,
			expectedInfo: ErrorInfo{Code: 7, Name: "REQUEST_TIMED_OUT", Description: "Request timed out", Retriable: true},
		},
		{
			name:         "Connection timeout",
			errorCode:    ErrorCodeConnectionTimeout,
			expectedInfo: ErrorInfo{Code: 61, Name: "CONNECTION_TIMEOUT", Description: "Connection timeout", Retriable: true},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			want := tc.expectedInfo
			got := GetErrorInfo(tc.errorCode)

			// Each field is compared on its own so a failure names the
			// exact field that diverged.
			if got.Code != want.Code {
				t.Errorf("GetErrorInfo().Code = %d, want %d", got.Code, want.Code)
			}
			if got.Name != want.Name {
				t.Errorf("GetErrorInfo().Name = %s, want %s", got.Name, want.Name)
			}
			if got.Description != want.Description {
				t.Errorf("GetErrorInfo().Description = %s, want %s", got.Description, want.Description)
			}
			if got.Retriable != want.Retriable {
				t.Errorf("GetErrorInfo().Retriable = %v, want %v", got.Retriable, want.Retriable)
			}
		})
	}
}
// TestIsRetriableError verifies the retriable classification returned by
// IsRetriableError for a set of error codes.
func TestIsRetriableError(t *testing.T) {
	cases := []struct {
		name      string
		errorCode int16
		retriable bool
	}{
		{name: "None", errorCode: ErrorCodeNone, retriable: false},
		{name: "Unknown server error", errorCode: ErrorCodeUnknownServerError, retriable: true},
		{name: "Topic already exists", errorCode: ErrorCodeTopicAlreadyExists, retriable: false},
		{name: "Request timed out", errorCode: ErrorCodeRequestTimedOut, retriable: true},
		{name: "Rebalance in progress", errorCode: ErrorCodeRebalanceInProgress, retriable: true},
		{name: "Invalid group ID", errorCode: ErrorCodeInvalidGroupID, retriable: false},
		{name: "Network exception", errorCode: ErrorCodeNetworkException, retriable: true},
		{name: "Connection timeout", errorCode: ErrorCodeConnectionTimeout, retriable: true},
		{name: "Read timeout", errorCode: ErrorCodeReadTimeout, retriable: true},
		{name: "Write timeout", errorCode: ErrorCodeWriteTimeout, retriable: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := IsRetriableError(tc.errorCode)
			if got == tc.retriable {
				return
			}
			t.Errorf("IsRetriableError() = %v, want %v", got, tc.retriable)
		})
	}
}
// TestBuildErrorResponse checks the output of BuildErrorResponse: its total
// length, the big-endian correlation ID in bytes 0-3 and the big-endian error
// code in bytes 4-5.
func TestBuildErrorResponse(t *testing.T) {
	cases := []struct {
		name          string
		correlationID uint32
		errorCode     int16
		expectedLen   int
	}{
		{name: "Basic error response", correlationID: 12345, errorCode: ErrorCodeUnknownServerError, expectedLen: 6},
		{name: "Topic already exists", correlationID: 67890, errorCode: ErrorCodeTopicAlreadyExists, expectedLen: 6},
		{name: "No error", correlationID: 11111, errorCode: ErrorCodeNone, expectedLen: 6},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp := BuildErrorResponse(tc.correlationID, tc.errorCode)
			size := len(resp)
			if size != tc.expectedLen {
				t.Errorf("BuildErrorResponse() length = %d, want %d", size, tc.expectedLen)
			}

			// The correlation ID needs the first four bytes; with fewer
			// there is nothing further to inspect.
			if size < 4 {
				return
			}
			if gotID := binary.BigEndian.Uint32(resp[:4]); gotID != tc.correlationID {
				t.Errorf("Correlation ID = %d, want %d", gotID, tc.correlationID)
			}

			// The error code needs two more bytes.
			if size < 6 {
				return
			}
			wantCode := uint16(tc.errorCode)
			if gotCode := binary.BigEndian.Uint16(resp[4:6]); gotCode != wantCode {
				t.Errorf("Error code = %d, want %d", gotCode, wantCode)
			}
		})
	}
}
// TestBuildErrorResponseWithMessage verifies the layout this test expects from
// BuildErrorResponseWithMessage: a 4-byte big-endian correlation ID, a 2-byte
// big-endian error code, and then either the null-string marker 0xFF 0xFF
// (for an empty message) or a 2-byte big-endian length followed by the
// message bytes.
//
// Length violations are reported with t.Fatalf so that a short response fails
// the subtest cleanly instead of panicking on an out-of-range slice, and so
// that a truncated message cannot slip through unchecked.
func TestBuildErrorResponseWithMessage(t *testing.T) {
	tests := []struct {
		name          string
		correlationID uint32
		errorCode     int16
		message       string
		expectNullMsg bool
	}{
		{"Error with message", 12345, ErrorCodeUnknownServerError, "Test error message", false},
		{"Error with empty message", 67890, ErrorCodeTopicAlreadyExists, "", true},
		{"Error with long message", 11111, ErrorCodeInvalidPartitions, "This is a longer error message for testing", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			response := BuildErrorResponseWithMessage(tt.correlationID, tt.errorCode, tt.message)

			// Should have at least correlation ID (4) + error code (2) + message length (2).
			const minLen = 8
			if len(response) < minLen {
				// Fatal: every check below slices into these bytes.
				t.Fatalf("BuildErrorResponseWithMessage() length = %d, want at least %d", len(response), minLen)
			}

			// Verify correlation ID.
			correlationID := binary.BigEndian.Uint32(response[0:4])
			if correlationID != tt.correlationID {
				t.Errorf("Correlation ID = %d, want %d", correlationID, tt.correlationID)
			}

			// Verify error code.
			errorCode := binary.BigEndian.Uint16(response[4:6])
			if errorCode != uint16(tt.errorCode) {
				t.Errorf("Error code = %d, want %d", errorCode, uint16(tt.errorCode))
			}

			// Verify message.
			if tt.expectNullMsg {
				// Should have null string marker (0xFF, 0xFF).
				if response[6] != 0xFF || response[7] != 0xFF {
					t.Errorf("Expected null string marker, got %x %x", response[6], response[7])
				}
				return
			}

			// Should have message length and message.
			messageLen := binary.BigEndian.Uint16(response[6:8])
			if messageLen != uint16(len(tt.message)) {
				t.Errorf("Message length = %d, want %d", messageLen, len(tt.message))
			}

			end := minLen + len(tt.message)
			if len(response) < end {
				// A truncated payload must fail rather than skip the comparison.
				t.Fatalf("Response length = %d, want at least %d to hold the message", len(response), end)
			}
			if actualMessage := string(response[minLen:end]); actualMessage != tt.message {
				t.Errorf("Message = %q, want %q", actualMessage, tt.message)
			}
		})
	}
}
// TestClassifyNetworkError checks the error code ClassifyNetworkError returns
// for nil, plain errors with specific messages, and the mock timeout/network
// error types declared in this file.
func TestClassifyNetworkError(t *testing.T) {
	cases := []struct {
		name     string
		err      error
		expected int16
	}{
		{name: "No error", err: nil, expected: ErrorCodeNone},
		{name: "Generic error", err: errors.New("generic error"), expected: ErrorCodeUnknownServerError},
		{name: "Connection refused", err: errors.New("connection refused"), expected: ErrorCodeConnectionRefused},
		{name: "Connection timeout", err: errors.New("connection timeout"), expected: ErrorCodeConnectionTimeout},
		{name: "Network timeout", err: &timeoutError{}, expected: ErrorCodeRequestTimedOut},
		{name: "Network error", err: &networkError{}, expected: ErrorCodeNetworkException},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := ClassifyNetworkError(tc.err)
			if got == tc.expected {
				return
			}
			t.Errorf("ClassifyNetworkError() = %v, want %v", got, tc.expected)
		})
	}
}
// TestHandleTimeoutError checks the error code HandleTimeoutError returns for
// each combination of error value and operation name in the table.
func TestHandleTimeoutError(t *testing.T) {
	cases := []struct {
		name      string
		err       error
		operation string
		expected  int16
	}{
		{name: "No error", err: nil, operation: "read", expected: ErrorCodeNone},
		{name: "Read timeout", err: &timeoutError{}, operation: "read", expected: ErrorCodeReadTimeout},
		{name: "Write timeout", err: &timeoutError{}, operation: "write", expected: ErrorCodeWriteTimeout},
		{name: "Connect timeout", err: &timeoutError{}, operation: "connect", expected: ErrorCodeConnectionTimeout},
		{name: "Generic timeout", err: &timeoutError{}, operation: "unknown", expected: ErrorCodeRequestTimedOut},
		{name: "Non-timeout error", err: errors.New("other error"), operation: "read", expected: ErrorCodeUnknownServerError},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := HandleTimeoutError(tc.err, tc.operation)
			if got == tc.expected {
				return
			}
			t.Errorf("HandleTimeoutError() = %v, want %v", got, tc.expected)
		})
	}
}
// TestDefaultTimeoutConfig pins the values returned by DefaultTimeoutConfig:
// 30s for the connection and request timeouts, 10s for the read and write
// timeouts.
func TestDefaultTimeoutConfig(t *testing.T) {
	const (
		longTimeout  = 30 * time.Second
		shortTimeout = 10 * time.Second
	)

	cfg := DefaultTimeoutConfig()

	if got := cfg.ConnectionTimeout; got != longTimeout {
		t.Errorf("ConnectionTimeout = %v, want %v", got, longTimeout)
	}
	if got := cfg.ReadTimeout; got != shortTimeout {
		t.Errorf("ReadTimeout = %v, want %v", got, shortTimeout)
	}
	if got := cfg.WriteTimeout; got != shortTimeout {
		t.Errorf("WriteTimeout = %v, want %v", got, shortTimeout)
	}
	if got := cfg.RequestTimeout; got != longTimeout {
		t.Errorf("RequestTimeout = %v, want %v", got, longTimeout)
	}
}
// TestSafeFormatError checks that SafeFormatError yields an empty string for a
// nil error and "Error: <message>" for the non-nil errors in the table.
func TestSafeFormatError(t *testing.T) {
	cases := []struct {
		name     string
		err      error
		expected string
	}{
		{name: "No error", err: nil, expected: ""},
		{name: "Generic error", err: errors.New("test error"), expected: "Error: test error"},
		{name: "Network error", err: &networkError{}, expected: "Error: network error"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := SafeFormatError(tc.err)
			if got == tc.expected {
				return
			}
			t.Errorf("SafeFormatError() = %q, want %q", got, tc.expected)
		})
	}
}
// TestGetErrorInfo_UnknownErrorCode checks the fallback ErrorInfo returned by
// GetErrorInfo for code 9999: the code is echoed back, the name is "UNKNOWN",
// the description is "Unknown error code" and the error is not retriable.
func TestGetErrorInfo_UnknownErrorCode(t *testing.T) {
	const unknownCode int16 = 9999

	got := GetErrorInfo(unknownCode)

	if got.Code != unknownCode {
		t.Errorf("Code = %d, want %d", got.Code, unknownCode)
	}
	if got.Name != "UNKNOWN" {
		t.Errorf("Name = %s, want UNKNOWN", got.Name)
	}
	if got.Description != "Unknown error code" {
		t.Errorf("Description = %s, want 'Unknown error code'", got.Description)
	}
	if got.Retriable {
		t.Errorf("Retriable = %v, want false", got.Retriable)
	}
}
// TestErrorHandling_Integration exercises BuildErrorResponse,
// BuildErrorResponseWithMessage and GetErrorInfo together for several error
// codes. The apiKey column labels each case; the assertions do not read it.
func TestErrorHandling_Integration(t *testing.T) {
	const correlationID uint32 = 12345

	cases := []struct {
		name      string
		apiKey    uint16
		errorCode int16
		message   string
	}{
		{name: "ApiVersions error", apiKey: 18, errorCode: ErrorCodeUnsupportedVersion, message: "Version not supported"},
		{name: "Metadata error", apiKey: 3, errorCode: ErrorCodeUnknownTopicOrPartition, message: "Topic not found"},
		{name: "Produce error", apiKey: 0, errorCode: ErrorCodeMessageTooLarge, message: "Message exceeds size limit"},
		{name: "Fetch error", apiKey: 1, errorCode: ErrorCodeOffsetOutOfRange, message: "Offset out of range"},
		{name: "CreateTopics error", apiKey: 19, errorCode: ErrorCodeTopicAlreadyExists, message: "Topic already exists"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The basic response is expected to be exactly six bytes.
			if n := len(BuildErrorResponse(correlationID, tc.errorCode)); n != 6 {
				t.Errorf("Basic response length = %d, want 6", n)
			}

			// Response with a message: 4 (correlationID) + 2 (errorCode) +
			// 2 (messageLen) + len(message) bytes at minimum.
			withMsg := BuildErrorResponseWithMessage(correlationID, tc.errorCode, tc.message)
			expectedMinLen := 8 + len(tc.message)
			if n := len(withMsg); n < expectedMinLen {
				t.Errorf("Message response length = %d, want at least %d", n, expectedMinLen)
			}

			// GetErrorInfo must echo back the code it was asked about.
			if code := GetErrorInfo(tc.errorCode).Code; code != tc.errorCode {
				t.Errorf("Error info code = %d, want %d", code, tc.errorCode)
			}
		})
	}
}
// Mock error types for testing.

// timeoutError is a stateless mock error whose Timeout and Temporary methods
// both report true.
type timeoutError struct{}

// Error returns the fixed message "timeout error".
func (*timeoutError) Error() string {
	return "timeout error"
}

// Timeout always reports true.
func (*timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (*timeoutError) Temporary() bool {
	return true
}
// networkError is a stateless mock error that is temporary but not a timeout:
// Timeout reports false and Temporary reports true.
type networkError struct{}

// Error returns the fixed message "network error".
func (*networkError) Error() string {
	return "network error"
}

// Timeout always reports false.
func (*networkError) Timeout() bool {
	return false
}

// Temporary always reports true.
func (*networkError) Temporary() bool {
	return true
}
// TestTimeoutDetection checks that HandleTimeoutError maps the error of an
// expired context to ErrorCodeRequestTimedOut.
func TestTimeoutDetection(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
	defer cancel()

	// Block until the deadline actually fires. Sleeping for a fixed interval
	// and then polling Done() with a non-blocking select races against the
	// context's timer and can fail spuriously on a loaded machine.
	<-ctx.Done()

	errorCode := HandleTimeoutError(ctx.Err(), "context")
	if errorCode != ErrorCodeRequestTimedOut {
		t.Errorf("Context timeout error code = %v, want %v", errorCode, ErrorCodeRequestTimedOut)
	}
}
// BenchmarkBuildErrorResponse measures BuildErrorResponse with a fixed
// correlation ID and ErrorCodeUnknownServerError.
func BenchmarkBuildErrorResponse(b *testing.B) {
	const correlationID uint32 = 12345
	code := ErrorCodeUnknownServerError

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = BuildErrorResponse(correlationID, code)
	}
}
// BenchmarkBuildErrorResponseWithMessage measures
// BuildErrorResponseWithMessage with a fixed correlation ID,
// ErrorCodeUnknownServerError and a short constant message.
func BenchmarkBuildErrorResponseWithMessage(b *testing.B) {
	const (
		correlationID uint32 = 12345
		message              = "This is a test error message"
	)
	code := ErrorCodeUnknownServerError

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = BuildErrorResponseWithMessage(correlationID, code, message)
	}
}
// BenchmarkClassifyNetworkError measures ClassifyNetworkError when given the
// mock timeoutError.
func BenchmarkClassifyNetworkError(b *testing.B) {
	timeoutErr := &timeoutError{}

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = ClassifyNetworkError(timeoutErr)
	}
}