diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go index 64397bdc1..465a6647e 100644 --- a/weed/mq/kafka/protocol/handler.go +++ b/weed/mq/kafka/protocol/handler.go @@ -279,51 +279,54 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { } } - // Read message size (4 bytes) + // Read message size (4 bytes) with context cancellation fmt.Printf("DEBUG: [%s] About to read message size header\n", connectionID) var sizeBytes [4]byte - if _, err := io.ReadFull(r, sizeBytes[:]); err != nil { - if err == io.EOF { - fmt.Printf("DEBUG: [%s] Client closed connection (clean EOF)\n", connectionID) - return nil // clean disconnect - } - // Check if it's a timeout error - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - fmt.Printf("DEBUG: [%s] Read timeout (likely due to context cancellation or client disconnect)\n", connectionID) - // Check if context was cancelled - select { - case <-ctx.Done(): - fmt.Printf("DEBUG: [%s] Context was cancelled, returning context error\n", connectionID) - return ctx.Err() - default: - fmt.Printf("DEBUG: [%s] Timeout without context cancellation, treating as client disconnect\n", connectionID) - return nil // treat as clean disconnect + // Use a channel to make the read operation cancellable + type readResult struct { + n int + err error + } + readChan := make(chan readResult, 1) + + go func() { + n, err := io.ReadFull(r, sizeBytes[:]) + readChan <- readResult{n: n, err: err} + }() + + // Wait for either the read to complete or context cancellation + select { + case <-ctx.Done(): + fmt.Printf("DEBUG: [%s] Context cancelled during read, closing connection\n", connectionID) + return ctx.Err() + case result := <-readChan: + if result.err != nil { + if result.err == io.EOF { + fmt.Printf("DEBUG: [%s] Client closed connection (clean EOF)\n", connectionID) + return nil // clean disconnect } - } - // Use centralized error classification - errorCode := ClassifyNetworkError(err) - switch errorCode { - case ErrorCodeRequestTimedOut: - // Check if error is due to context cancellation - select { - case <-ctx.Done(): - fmt.Printf("DEBUG: [%s] Read timeout due to context cancellation\n", connectionID) - return ctx.Err() - default: - fmt.Printf("DEBUG: [%s] Read timeout: %v\n", connectionID, err) - return fmt.Errorf("read timeout: %w", err) + // Check if it's a timeout error + if netErr, ok := result.err.(net.Error); ok && netErr.Timeout() { + fmt.Printf("DEBUG: [%s] Read timeout (likely due to context cancellation or client disconnect)\n", connectionID) + // Check if context was cancelled + select { + case <-ctx.Done(): + fmt.Printf("DEBUG: [%s] Context was cancelled, returning context error\n", connectionID) + return ctx.Err() + default: + fmt.Printf("DEBUG: [%s] Timeout without context cancellation, treating as client disconnect\n", connectionID) + return nil // treat as clean disconnect + } } - case ErrorCodeNetworkException: - fmt.Printf("DEBUG: [%s] Network error reading message size: %v\n", connectionID, err) - return fmt.Errorf("network error: %w", err) - default: - fmt.Printf("DEBUG: [%s] Error reading message size: %v (code: %d)\n", connectionID, err, errorCode) - return fmt.Errorf("read size: %w", err) + + fmt.Printf("DEBUG: [%s] Read error: %v\n", connectionID, result.err) + return fmt.Errorf("read message size: %w", result.err) } } + // Successfully read the message size size := binary.BigEndian.Uint32(sizeBytes[:]) fmt.Printf("DEBUG: [%s] Read message size header: %d bytes\n", connectionID, size) if size == 0 || size > 1024*1024 { // 1MB limit