diff --git a/weed/mq/kafka/protocol/fetch.go b/weed/mq/kafka/protocol/fetch.go index f2216ff95..ced44ac36 100644 --- a/weed/mq/kafka/protocol/fetch.go +++ b/weed/mq/kafka/protocol/fetch.go @@ -1,6 +1,7 @@ package protocol import ( + "context" "encoding/binary" "fmt" "hash/crc32" @@ -12,7 +13,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" ) -func (h *Handler) handleFetch(correlationID uint32, apiVersion uint16, requestBody []byte) ([]byte, error) { +func (h *Handler) handleFetch(ctx context.Context, correlationID uint32, apiVersion uint16, requestBody []byte) ([]byte, error) { fmt.Printf("DEBUG: *** FETCH HANDLER CALLED *** Correlation: %d, Version: %d\n", correlationID, apiVersion) // Parse the Fetch request to get the requested topics and partitions fetchRequest, err := h.parseFetchRequest(apiVersion, requestBody) @@ -55,6 +56,15 @@ func (h *Handler) handleFetch(correlationID uint32, apiVersion uint16, requestBo start := time.Now() deadline := start.Add(time.Duration(maxWaitMs) * time.Millisecond) for time.Now().Before(deadline) { + // Check for context cancellation first + select { + case <-ctx.Done(): + fmt.Printf("DEBUG: Fetch polling cancelled due to context cancellation\n") + throttleTimeMs = int32(time.Since(start) / time.Millisecond) + break + default: + } + time.Sleep(10 * time.Millisecond) if hasDataAvailable() { break diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go index 465a6647e..b5c07c5db 100644 --- a/weed/mq/kafka/protocol/handler.go +++ b/weed/mq/kafka/protocol/handler.go @@ -443,7 +443,7 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { response, err = h.handleProduce(correlationID, apiVersion, requestBody) case 1: // Fetch fmt.Printf("DEBUG: *** FETCH HANDLER CALLED *** Correlation: %d, Version: %d\n", correlationID, apiVersion) - response, err = h.handleFetch(correlationID, apiVersion, requestBody) + response, err = h.handleFetch(ctx, correlationID, apiVersion, requestBody) if err != nil { fmt.Printf("DEBUG: Fetch error: %v\n", err) } else {