diff --git a/weed/mq/kafka/protocol/errors.go b/weed/mq/kafka/protocol/errors.go index 93bc85c80..a595afb31 100644 --- a/weed/mq/kafka/protocol/errors.go +++ b/weed/mq/kafka/protocol/errors.go @@ -249,7 +249,8 @@ func IsRetriableError(code int16) bool { return GetErrorInfo(code).Retriable } -// BuildErrorResponse builds a standard Kafka error response +// BuildErrorResponse builds a standard Kafka error response (2-byte error code only). +// Prefer BuildAPIErrorResponse for API-aware error bodies. func BuildErrorResponse(correlationID uint32, errorCode int16) []byte { response := make([]byte, 0, 8) @@ -264,6 +265,228 @@ func BuildErrorResponse(correlationID uint32, errorCode int16) []byte { return response } +// BuildAPIErrorResponse builds a minimal-but-valid error response body whose +// layout matches the schema the client expects for the given API key and +// version. The correlation ID and header-level tagged fields are NOT included +// (those are added by writeResponseWithHeader). +func BuildAPIErrorResponse(apiKey, apiVersion uint16, errorCode int16) []byte { + ec := make([]byte, 2) + binary.BigEndian.PutUint16(ec, uint16(errorCode)) + throttle := []byte{0, 0, 0, 0} // throttle_time_ms = 0 + emptyArr := []byte{0, 0, 0, 0} // regular array length = 0 + nullStr := []byte{0xFF, 0xFF} // nullable string = null + emptyStr := []byte{0, 0} // string length = 0 + + switch APIKey(apiKey) { + + // --- error_code is the first body field ----------------------------------- + + case APIKeyApiVersions: + // error_code(2) + api_keys_array + [throttle_time_ms v1+] [+ tagged_fields v3+] + buf := append([]byte{}, ec...) + if apiVersion >= 3 { + buf = append(buf, 1) // compact array length=0 (varint 1 = 0+1) + } else { + buf = append(buf, emptyArr...) + } + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + if apiVersion >= 3 { + buf = append(buf, 0) // body-level tagged fields + } + return buf + + // --- throttle_time_ms(4) + error_code(2) + trailing fields ---------------- + + case APIKeyFindCoordinator: + // [throttle v1+] + error_code + error_msg + node_id + host + port + buf := make([]byte, 0, 24) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + buf = append(buf, nullStr...) // error_message + buf = append(buf, 0xFF, 0xFF, 0xFF, 0xFF) // node_id = -1 + buf = append(buf, emptyStr...) // host + buf = append(buf, 0, 0, 0, 0) // port = 0 + return buf + + case APIKeyJoinGroup: + // [throttle v2+] + error_code + generation_id + protocol_name + leader + member_id + members[] + buf := make([]byte, 0, 24) + if apiVersion >= 2 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + buf = append(buf, 0xFF, 0xFF, 0xFF, 0xFF) // generation_id = -1 + buf = append(buf, nullStr...) // protocol_name + buf = append(buf, emptyStr...) // leader + buf = append(buf, emptyStr...) // member_id + buf = append(buf, emptyArr...) // members = [] + return buf + + case APIKeySyncGroup: + // [throttle v1+] + error_code + assignment + buf := make([]byte, 0, 12) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + buf = append(buf, 0, 0, 0, 0) // assignment bytes length = 0 + return buf + + case APIKeyHeartbeat: + // [throttle v1+] + error_code + buf := make([]byte, 0, 8) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + return buf + + case APIKeyLeaveGroup: + // [throttle v1+] + error_code [+ members v3+] + buf := make([]byte, 0, 12) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + if apiVersion >= 3 { + buf = append(buf, emptyArr...) // members = [] + } + return buf + + case APIKeyInitProducerId: + // throttle + error_code + producer_id + producer_epoch + buf := make([]byte, 0, 18) + buf = append(buf, throttle...) + buf = append(buf, ec...) + buf = append(buf, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF) // producer_id = -1 + buf = append(buf, 0xFF, 0xFF) // producer_epoch = -1 + return buf + + case APIKeyListGroups: + // [throttle v1+] + error_code + groups[] + buf := make([]byte, 0, 12) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + buf = append(buf, ec...) + buf = append(buf, emptyArr...) // groups = [] + return buf + + case APIKeyDescribeCluster: + // throttle + error_code + error_msg + cluster_id + controller_id + brokers[] + cluster_authorized_operations + buf := make([]byte, 0, 24) + buf = append(buf, throttle...) + buf = append(buf, ec...) + buf = append(buf, nullStr...) // error_message + buf = append(buf, emptyStr...) // cluster_id + buf = append(buf, 0xFF, 0xFF, 0xFF, 0xFF) // controller_id = -1 + buf = append(buf, emptyArr...) // brokers = [] + buf = append(buf, 0, 0, 0, 0) // cluster_authorized_operations + return buf + + // --- array-based responses (no top-level error_code) ---------------------- + + case APIKeyProduce: + // topics[] + [throttle v1+] + buf := append([]byte{}, emptyArr...) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + return buf + + case APIKeyFetch: + // [throttle v1+] [+ error_code + session_id v7+] + topics[] + buf := make([]byte, 0, 16) + if apiVersion >= 1 { + buf = append(buf, throttle...) + } + if apiVersion >= 7 { + buf = append(buf, ec...) // error_code + buf = append(buf, 0, 0, 0, 0) // session_id = 0 + } + buf = append(buf, emptyArr...) // topics = [] + return buf + + case APIKeyMetadata: + // [throttle v3+] + brokers[] + [cluster_id v2+] + [controller_id v1+] + topics[] + buf := make([]byte, 0, 24) + if apiVersion >= 3 { + buf = append(buf, throttle...) + } + buf = append(buf, emptyArr...) // brokers = [] + if apiVersion >= 2 { + buf = append(buf, nullStr...) // cluster_id + } + if apiVersion >= 1 { + buf = append(buf, 0xFF, 0xFF, 0xFF, 0xFF) // controller_id = -1 + } + buf = append(buf, emptyArr...) // topics = [] + return buf + + case APIKeyListOffsets: + // [throttle v2+] + topics[] + buf := make([]byte, 0, 8) + if apiVersion >= 2 { + buf = append(buf, throttle...) + } + buf = append(buf, emptyArr...) + return buf + + case APIKeyOffsetCommit: + // [throttle v3+] + topics[] + buf := make([]byte, 0, 8) + if apiVersion >= 3 { + buf = append(buf, throttle...) + } + buf = append(buf, emptyArr...) + return buf + + case APIKeyOffsetFetch: + // [throttle v3+] + topics[] [+ error_code v2+] + buf := make([]byte, 0, 12) + if apiVersion >= 3 { + buf = append(buf, throttle...) + } + buf = append(buf, emptyArr...) // topics = [] + if apiVersion >= 2 { + buf = append(buf, ec...) // error_code + } + return buf + + case APIKeyCreateTopics: + // throttle + topics[] + buf := append([]byte{}, throttle...) + buf = append(buf, emptyArr...) + return buf + + case APIKeyDeleteTopics: + // throttle + topics[] + buf := append([]byte{}, throttle...) + buf = append(buf, emptyArr...) + return buf + + case APIKeyDescribeGroups: + // throttle + groups[] + buf := append([]byte{}, throttle...) + buf = append(buf, emptyArr...) + return buf + + case APIKeyDescribeConfigs: + // throttle + resources[] + buf := append([]byte{}, throttle...) + buf = append(buf, emptyArr...) + return buf + + default: + // Unknown API — emit just the error code as a best-effort fallback + return append([]byte{}, ec...) + } +} + // BuildErrorResponseWithMessage builds a Kafka error response with error message func BuildErrorResponseWithMessage(correlationID uint32, errorCode int16, message string) []byte { response := BuildErrorResponse(correlationID, errorCode) diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go index 8dffd2313..707c38041 100644 --- a/weed/mq/kafka/protocol/handler.go +++ b/weed/mq/kafka/protocol/handler.go @@ -626,14 +626,21 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { } // Send this response + var respBody []byte if readyResp.err != nil { glog.Errorf("[%s] Error processing correlation=%d: %v", connectionID, readyResp.correlationID, readyResp.err) + // Build an API-versioned error response so the body matches + // the schema the client expects for this API key/version. + // A generic 2-byte error code would corrupt the protocol + // stream for APIs that start with throttle_time or arrays. + respBody = BuildAPIErrorResponse(readyResp.apiKey, readyResp.apiVersion, ErrorCodeUnknownServerError) } else { - if writeErr := h.writeResponseWithHeader(w, readyResp.correlationID, readyResp.apiKey, readyResp.apiVersion, readyResp.response, timeoutConfig.WriteTimeout); writeErr != nil { - glog.Errorf("[%s] Response writer WRITE ERROR correlation=%d: %v - EXITING", connectionID, readyResp.correlationID, writeErr) - correlationQueueMu.Unlock() - return - } + respBody = readyResp.response + } + if writeErr := h.writeResponseWithHeader(w, readyResp.correlationID, readyResp.apiKey, readyResp.apiVersion, respBody, timeoutConfig.WriteTimeout); writeErr != nil { + glog.Errorf("[%s] Response writer WRITE ERROR correlation=%d: %v - EXITING", connectionID, readyResp.correlationID, writeErr) + correlationQueueMu.Unlock() + return } // Remove from pending and advance @@ -689,7 +696,10 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { // Connection closed, stop processing return case <-time.After(5 * time.Second): - glog.Warningf("[%s] Control plane: timeout sending response correlation=%d", connectionID, req.correlationID) + // responseChan stuck — cancel context to tear down connection. + // The orphaned correlationID would stall the ordered writer permanently. + glog.Warningf("[%s] Control plane: timeout sending response correlation=%d, tearing down connection", connectionID, req.correlationID) + cancel() } case <-ctx.Done(): // Context cancelled, drain remaining requests before exiting @@ -767,7 +777,10 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { // Connection closed, stop processing return case <-time.After(5 * time.Second): - glog.Warningf("[%s] Data plane: timeout sending response correlation=%d", connectionID, req.correlationID) + // responseChan stuck — cancel context to tear down connection. + // The orphaned correlationID would stall the ordered writer permanently. + glog.Warningf("[%s] Data plane: timeout sending response correlation=%d, tearing down connection", connectionID, req.correlationID) + cancel() } case <-ctx.Done(): // Context cancelled, drain remaining requests before exiting @@ -804,13 +817,18 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { }() defer func() { - // Close channels in correct order to avoid panics - // 1. Close input channels to stop accepting new requests + // Cancel context FIRST so the response writer (and any worker stuck + // in an inner select) sees ctx.Done() and can exit. Previously + // cancel() ran in a separate defer registered earlier (LIFO: later + // defer runs first), so wg.Wait() below would deadlock waiting for + // the response writer which was itself waiting for ctx.Done(). + cancel() + // Close input channels to stop accepting new requests close(controlChan) close(dataChan) - // 2. Wait for worker goroutines to finish processing and sending responses + // Wait for worker goroutines to finish processing and sending responses wg.Wait() - // 3. NOW close responseChan to signal response writer to exit + // NOW close responseChan (safe — all senders have exited) close(responseChan) }() @@ -872,14 +890,21 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { if err := conn.SetReadDeadline(time.Now().Add(timeoutConfig.ReadTimeout)); err != nil { } - // Read the message - // OPTIMIZATION: Use buffer pool to reduce GC pressure (was 1MB/sec at 1000 req/s) - messageBuf := mem.Allocate(int(size)) - defer mem.Free(messageBuf) - if _, err := io.ReadFull(r, messageBuf); err != nil { + // Read the message into a pooled buffer, then copy to owned memory + // and return the pool buffer immediately. The previous code used + // defer mem.Free which accumulated one deferred free per iteration, + // leaking pool buffers for the connection lifetime and risking + // use-after-free when the defers ran before processing goroutines + // finished draining. + poolBuf := mem.Allocate(int(size)) + if _, err := io.ReadFull(r, poolBuf); err != nil { + mem.Free(poolBuf) _ = HandleTimeoutError(err, "read") // errorCode return fmt.Errorf("read message: %w", err) } + messageBuf := make([]byte, int(size)) + copy(messageBuf, poolBuf) + mem.Free(poolBuf) // Parse at least the basic header to get API key and correlation ID if len(messageBuf) < 8 { @@ -898,6 +923,13 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { if writeErr != nil { return fmt.Errorf("build error response: %w", writeErr) } + // Add to correlation queue BEFORE sending to responseChan so the + // response writer can match and send it. Without this the response + // sits in pendingResponses forever and the client hangs. + correlationQueueMu.Lock() + correlationQueue = append(correlationQueue, correlationID) + correlationQueueMu.Unlock() + // Send error response through response queue to maintain sequential ordering select { case responseChan <- &kafkaResponse{ @@ -1016,28 +1048,49 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { // Route to appropriate channel based on API key var targetChan chan *kafkaRequest - if apiKey == 2 { // ListOffsets - } if isDataPlaneAPI(apiKey) { targetChan = dataChan } else { targetChan = controlChan } - // Only add to correlation queue AFTER successful channel send - // If we add before and the channel blocks, the correlation ID is in the queue - // but the request never gets processed, causing response writer deadlock + // Add correlation ID BEFORE channel send to prevent race condition: + // If we add after, the processor can finish and send the response before the + // ID is in the queue. The response writer then can't match the response to a + // queue entry, causing a deadlock for sequential clients (e.g., Sarama). + // If the channel send fails, we send an error response so the response writer + // can advance past this entry. + correlationQueueMu.Lock() + correlationQueue = append(correlationQueue, correlationID) + correlationQueueMu.Unlock() + select { case targetChan <- req: - // Request queued successfully - NOW add to correlation tracking - correlationQueueMu.Lock() - correlationQueue = append(correlationQueue, correlationID) - correlationQueueMu.Unlock() + // Request queued successfully case <-ctx.Done(): + // Context cancelled - send error response so response writer can advance + select { + case responseChan <- &kafkaResponse{ + correlationID: correlationID, + apiKey: apiKey, + apiVersion: apiVersion, + err: ctx.Err(), + }: + default: + } return ctx.Err() case <-time.After(10 * time.Second): - // Channel full for too long - this shouldn't happen with proper backpressure + // Channel full for too long - send error response so response writer can advance glog.Errorf("[%s] Failed to queue correlation=%d - channel full (10s timeout)", connectionID, correlationID) + select { + case responseChan <- &kafkaResponse{ + correlationID: correlationID, + apiKey: apiKey, + apiVersion: apiVersion, + err: fmt.Errorf("request queue full"), + }: + default: + } return fmt.Errorf("request queue full: correlation=%d", correlationID) } } diff --git a/weed/mq/kafka/protocol/produce.go b/weed/mq/kafka/protocol/produce.go index 849d1148d..a35a1fb21 100644 --- a/weed/mq/kafka/protocol/produce.go +++ b/weed/mq/kafka/protocol/produce.go @@ -28,31 +28,21 @@ func (h *Handler) handleProduce(ctx context.Context, correlationID uint32, apiVe } func (h *Handler) handleProduceV0V1(ctx context.Context, correlationID uint32, apiVersion uint16, requestBody []byte) ([]byte, error) { - // Parse Produce v0/v1 request - // Request format: client_id + acks(2) + timeout(4) + topics_array + // Parse Produce v0/v1 request body (client_id already stripped in HandleConn) + // Body format: acks(INT16) + timeout_ms(INT32) + topics(ARRAY) - if len(requestBody) < 8 { // client_id_size(2) + acks(2) + timeout(4) + if len(requestBody) < 10 { // acks(2) + timeout_ms(4) + topics_count(4) return nil, fmt.Errorf("Produce request too short") } - // Skip client_id - clientIDSize := binary.BigEndian.Uint16(requestBody[0:2]) - - if len(requestBody) < 2+int(clientIDSize) { - return nil, fmt.Errorf("Produce request client_id too short") - } - - _ = string(requestBody[2 : 2+int(clientIDSize)]) // clientID - offset := 2 + int(clientIDSize) - - if len(requestBody) < offset+10 { // acks(2) + timeout(4) + topics_count(4) - return nil, fmt.Errorf("Produce request missing data") - } + offset := 0 - // Parse acks and timeout _ = int16(binary.BigEndian.Uint16(requestBody[offset : offset+2])) // acks offset += 2 + _ = binary.BigEndian.Uint32(requestBody[offset : offset+4]) // timeout_ms + offset += 4 + topicsCount := binary.BigEndian.Uint32(requestBody[offset : offset+4]) offset += 4