|
|
|
@ -279,51 +279,54 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Read message size (4 bytes)
|
|
|
|
// Read message size (4 bytes) with context cancellation
|
|
|
|
fmt.Printf("DEBUG: [%s] About to read message size header\n", connectionID) |
|
|
|
var sizeBytes [4]byte |
|
|
|
if _, err := io.ReadFull(r, sizeBytes[:]); err != nil { |
|
|
|
if err == io.EOF { |
|
|
|
fmt.Printf("DEBUG: [%s] Client closed connection (clean EOF)\n", connectionID) |
|
|
|
return nil // clean disconnect
|
|
|
|
} |
|
|
|
|
|
|
|
// Check if it's a timeout error
|
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
|
|
|
fmt.Printf("DEBUG: [%s] Read timeout (likely due to context cancellation or client disconnect)\n", connectionID) |
|
|
|
// Check if context was cancelled
|
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Context was cancelled, returning context error\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
default: |
|
|
|
fmt.Printf("DEBUG: [%s] Timeout without context cancellation, treating as client disconnect\n", connectionID) |
|
|
|
return nil // treat as clean disconnect
|
|
|
|
// Use a channel to make the read operation cancellable
|
|
|
|
type readResult struct { |
|
|
|
n int |
|
|
|
err error |
|
|
|
} |
|
|
|
readChan := make(chan readResult, 1) |
|
|
|
|
|
|
|
go func() { |
|
|
|
n, err := io.ReadFull(r, sizeBytes[:]) |
|
|
|
readChan <- readResult{n: n, err: err} |
|
|
|
}() |
|
|
|
|
|
|
|
// Wait for either the read to complete or context cancellation
|
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Context cancelled during read, closing connection\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
case result := <-readChan: |
|
|
|
if result.err != nil { |
|
|
|
if result.err == io.EOF { |
|
|
|
fmt.Printf("DEBUG: [%s] Client closed connection (clean EOF)\n", connectionID) |
|
|
|
return nil // clean disconnect
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Use centralized error classification
|
|
|
|
errorCode := ClassifyNetworkError(err) |
|
|
|
switch errorCode { |
|
|
|
case ErrorCodeRequestTimedOut: |
|
|
|
// Check if error is due to context cancellation
|
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Read timeout due to context cancellation\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
default: |
|
|
|
fmt.Printf("DEBUG: [%s] Read timeout: %v\n", connectionID, err) |
|
|
|
return fmt.Errorf("read timeout: %w", err) |
|
|
|
// Check if it's a timeout error
|
|
|
|
if netErr, ok := result.err.(net.Error); ok && netErr.Timeout() { |
|
|
|
fmt.Printf("DEBUG: [%s] Read timeout (likely due to context cancellation or client disconnect)\n", connectionID) |
|
|
|
// Check if context was cancelled
|
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Context was cancelled, returning context error\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
default: |
|
|
|
fmt.Printf("DEBUG: [%s] Timeout without context cancellation, treating as client disconnect\n", connectionID) |
|
|
|
return nil // treat as clean disconnect
|
|
|
|
} |
|
|
|
} |
|
|
|
case ErrorCodeNetworkException: |
|
|
|
fmt.Printf("DEBUG: [%s] Network error reading message size: %v\n", connectionID, err) |
|
|
|
return fmt.Errorf("network error: %w", err) |
|
|
|
default: |
|
|
|
fmt.Printf("DEBUG: [%s] Error reading message size: %v (code: %d)\n", connectionID, err, errorCode) |
|
|
|
return fmt.Errorf("read size: %w", err) |
|
|
|
|
|
|
|
fmt.Printf("DEBUG: [%s] Read error: %v\n", connectionID, result.err) |
|
|
|
return fmt.Errorf("read message size: %w", result.err) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Successfully read the message size
|
|
|
|
size := binary.BigEndian.Uint32(sizeBytes[:]) |
|
|
|
fmt.Printf("DEBUG: [%s] Read message size header: %d bytes\n", connectionID, size) |
|
|
|
if size == 0 || size > 1024*1024 { // 1MB limit
|
|
|
|
|