|
|
|
@ -3,6 +3,7 @@ package protocol |
|
|
|
import ( |
|
|
|
"bufio" |
|
|
|
"bytes" |
|
|
|
"context" |
|
|
|
"encoding/binary" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
|
@ -333,7 +334,7 @@ func (h *Handler) SetBrokerAddress(host string, port int) { |
|
|
|
} |
|
|
|
|
|
|
|
// HandleConn processes a single client connection
|
|
|
|
func (h *Handler) HandleConn(conn net.Conn) error { |
|
|
|
func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { |
|
|
|
connectionID := fmt.Sprintf("%s->%s", conn.RemoteAddr(), conn.LocalAddr()) |
|
|
|
defer func() { |
|
|
|
fmt.Printf("DEBUG: [%s] Connection closing\n", connectionID) |
|
|
|
@ -345,6 +346,22 @@ func (h *Handler) HandleConn(conn net.Conn) error { |
|
|
|
defer w.Flush() |
|
|
|
|
|
|
|
for { |
|
|
|
// Check if context is cancelled
|
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Context cancelled, closing connection\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
default: |
|
|
|
} |
|
|
|
|
|
|
|
// Set a read deadline for the connection based on context
|
|
|
|
if deadline, ok := ctx.Deadline(); ok { |
|
|
|
conn.SetReadDeadline(deadline) |
|
|
|
} else { |
|
|
|
// Set a reasonable timeout if no deadline is set
|
|
|
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) |
|
|
|
} |
|
|
|
|
|
|
|
// Read message size (4 bytes)
|
|
|
|
var sizeBytes [4]byte |
|
|
|
if _, err := io.ReadFull(r, sizeBytes[:]); err != nil { |
|
|
|
@ -352,6 +369,16 @@ func (h *Handler) HandleConn(conn net.Conn) error { |
|
|
|
fmt.Printf("DEBUG: Client closed connection (clean EOF)\n") |
|
|
|
return nil // clean disconnect
|
|
|
|
} |
|
|
|
// Check if error is due to context cancellation
|
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
|
|
|
select { |
|
|
|
case <-ctx.Done(): |
|
|
|
fmt.Printf("DEBUG: [%s] Read timeout due to context cancellation\n", connectionID) |
|
|
|
return ctx.Err() |
|
|
|
default: |
|
|
|
// Actual timeout, continue with error
|
|
|
|
} |
|
|
|
} |
|
|
|
fmt.Printf("DEBUG: Error reading message size: %v\n", err) |
|
|
|
return fmt.Errorf("read size: %w", err) |
|
|
|
} |
|
|
|
|