Browse Source

context-cancellable read

pull/7231/head
chrislu 2 months ago
parent
commit
5a2fd1413f
  1. 49
      weed/mq/kafka/protocol/handler.go

49
weed/mq/kafka/protocol/handler.go

@ -279,17 +279,36 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
}
}
// Read message size (4 bytes)
// Read message size (4 bytes) with context cancellation
fmt.Printf("DEBUG: [%s] About to read message size header\n", connectionID)
var sizeBytes [4]byte
if _, err := io.ReadFull(r, sizeBytes[:]); err != nil {
if err == io.EOF {
// Use a channel to make the read operation cancellable
type readResult struct {
n int
err error
}
readChan := make(chan readResult, 1)
go func() {
n, err := io.ReadFull(r, sizeBytes[:])
readChan <- readResult{n: n, err: err}
}()
// Wait for either the read to complete or context cancellation
select {
case <-ctx.Done():
fmt.Printf("DEBUG: [%s] Context cancelled during read, closing connection\n", connectionID)
return ctx.Err()
case result := <-readChan:
if result.err != nil {
if result.err == io.EOF {
fmt.Printf("DEBUG: [%s] Client closed connection (clean EOF)\n", connectionID)
return nil // clean disconnect
}
// Check if it's a timeout error
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if netErr, ok := result.err.(net.Error); ok && netErr.Timeout() {
fmt.Printf("DEBUG: [%s] Read timeout (likely due to context cancellation or client disconnect)\n", connectionID)
// Check if context was cancelled
select {
@ -302,28 +321,12 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
}
}
// Use centralized error classification
errorCode := ClassifyNetworkError(err)
switch errorCode {
case ErrorCodeRequestTimedOut:
// Check if error is due to context cancellation
select {
case <-ctx.Done():
fmt.Printf("DEBUG: [%s] Read timeout due to context cancellation\n", connectionID)
return ctx.Err()
default:
fmt.Printf("DEBUG: [%s] Read timeout: %v\n", connectionID, err)
return fmt.Errorf("read timeout: %w", err)
}
case ErrorCodeNetworkException:
fmt.Printf("DEBUG: [%s] Network error reading message size: %v\n", connectionID, err)
return fmt.Errorf("network error: %w", err)
default:
fmt.Printf("DEBUG: [%s] Error reading message size: %v (code: %d)\n", connectionID, err, errorCode)
return fmt.Errorf("read size: %w", err)
fmt.Printf("DEBUG: [%s] Read error: %v\n", connectionID, result.err)
return fmt.Errorf("read message size: %w", result.err)
}
}
// Successfully read the message size
size := binary.BigEndian.Uint32(sizeBytes[:])
fmt.Printf("DEBUG: [%s] Read message size header: %d bytes\n", connectionID, size)
if size == 0 || size > 1024*1024 { // 1MB limit

Loading…
Cancel
Save