diff --git a/weed/mq/kafka/protocol/consumer_coordination.go b/weed/mq/kafka/protocol/consumer_coordination.go new file mode 100644 index 000000000..baac1e8bf --- /dev/null +++ b/weed/mq/kafka/protocol/consumer_coordination.go @@ -0,0 +1,387 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "time" + + "github.com/seaweedfs/seaweedfs/weed/mq/kafka/consumer" +) + +// Heartbeat API (key 12) - Consumer group heartbeat +// Consumers send periodic heartbeats to stay in the group and receive rebalancing signals + +// HeartbeatRequest represents a Heartbeat request from a Kafka client +type HeartbeatRequest struct { + GroupID string + GenerationID int32 + MemberID string + GroupInstanceID string // Optional static membership ID +} + +// HeartbeatResponse represents a Heartbeat response to a Kafka client +type HeartbeatResponse struct { + CorrelationID uint32 + ErrorCode int16 +} + +// LeaveGroup API (key 13) - Consumer graceful departure +// Consumers call this when shutting down to trigger immediate rebalancing + +// LeaveGroupRequest represents a LeaveGroup request from a Kafka client +type LeaveGroupRequest struct { + GroupID string + MemberID string + GroupInstanceID string // Optional static membership ID + Members []LeaveGroupMember // For newer versions, can leave multiple members +} + +// LeaveGroupMember represents a member leaving the group (for batch departures) +type LeaveGroupMember struct { + MemberID string + GroupInstanceID string + Reason string // Optional reason for leaving +} + +// LeaveGroupResponse represents a LeaveGroup response to a Kafka client +type LeaveGroupResponse struct { + CorrelationID uint32 + ErrorCode int16 + Members []LeaveGroupMemberResponse // Per-member responses for newer versions +} + +// LeaveGroupMemberResponse represents per-member leave group response +type LeaveGroupMemberResponse struct { + MemberID string + GroupInstanceID string + ErrorCode int16 +} + +// Error codes specific to consumer coordination +const ( + ErrorCodeUnstableOffsetCommit int16 = 95 // Consumer group is rebalancing + ErrorCodeGroupMaxSizeReached int16 = 84 // Group has reached maximum size +) + +func (h *Handler) handleHeartbeat(correlationID uint32, requestBody []byte) ([]byte, error) { + // Parse Heartbeat request + request, err := h.parseHeartbeatRequest(requestBody) + if err != nil { + return h.buildHeartbeatErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + // Validate request + if request.GroupID == "" || request.MemberID == "" { + return h.buildHeartbeatErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + // Get consumer group + group := h.groupCoordinator.GetGroup(request.GroupID) + if group == nil { + return h.buildHeartbeatErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + group.Mu.Lock() + defer group.Mu.Unlock() + + // Update group's last activity + group.LastActivity = time.Now() + + // Validate member exists + member, exists := group.Members[request.MemberID] + if !exists { + return h.buildHeartbeatErrorResponse(correlationID, ErrorCodeUnknownMemberID), nil + } + + // Validate generation + if request.GenerationID != group.Generation { + return h.buildHeartbeatErrorResponse(correlationID, ErrorCodeIllegalGeneration), nil + } + + // Update member's last heartbeat + member.LastHeartbeat = time.Now() + + // Check if rebalancing is in progress + var errorCode int16 = ErrorCodeNone + switch group.State { + case consumer.GroupStatePreparingRebalance, consumer.GroupStateCompletingRebalance: + // Signal the consumer that rebalancing is happening + errorCode = ErrorCodeRebalanceInProgress + case consumer.GroupStateDead: + errorCode = ErrorCodeInvalidGroupID + case consumer.GroupStateEmpty: + // This shouldn't happen if member exists, but handle gracefully + errorCode = ErrorCodeUnknownMemberID + case consumer.GroupStateStable: + // Normal case - heartbeat accepted + errorCode = ErrorCodeNone + } + + // Build successful response + response := HeartbeatResponse{ + CorrelationID: correlationID, + ErrorCode: errorCode, + } + + return h.buildHeartbeatResponse(response), nil +} + +func (h *Handler) handleLeaveGroup(correlationID uint32, requestBody []byte) ([]byte, error) { + // Parse LeaveGroup request + request, err := h.parseLeaveGroupRequest(requestBody) + if err != nil { + return h.buildLeaveGroupErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + // Validate request + if request.GroupID == "" || request.MemberID == "" { + return h.buildLeaveGroupErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + // Get consumer group + group := h.groupCoordinator.GetGroup(request.GroupID) + if group == nil { + return h.buildLeaveGroupErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil + } + + group.Mu.Lock() + defer group.Mu.Unlock() + + // Update group's last activity + group.LastActivity = time.Now() + + // Validate member exists + _, exists := group.Members[request.MemberID] + if !exists { + return h.buildLeaveGroupErrorResponse(correlationID, ErrorCodeUnknownMemberID), nil + } + + // Remove the member from the group + delete(group.Members, request.MemberID) + + // Update group state based on remaining members + if len(group.Members) == 0 { + // Group becomes empty + group.State = consumer.GroupStateEmpty + group.Generation++ + group.Leader = "" + } else { + // Trigger rebalancing for remaining members + group.State = consumer.GroupStatePreparingRebalance + group.Generation++ + + // If the leaving member was the leader, select a new leader + if group.Leader == request.MemberID { + // Select first remaining member as new leader + for memberID := range group.Members { + group.Leader = memberID + break + } + } + + // Mark remaining members as pending to trigger rebalancing + for _, member := range group.Members { + member.State = consumer.MemberStatePending + } + } + + // Update group's subscribed topics (may have changed with member leaving) + h.updateGroupSubscriptionFromMembers(group) + + // Build successful response + response := LeaveGroupResponse{ + CorrelationID: correlationID, + ErrorCode: ErrorCodeNone, + Members: []LeaveGroupMemberResponse{ + { + MemberID: request.MemberID, + GroupInstanceID: request.GroupInstanceID, + ErrorCode: ErrorCodeNone, + }, + }, + } + + return h.buildLeaveGroupResponse(response), nil +} + +func (h *Handler) parseHeartbeatRequest(data []byte) (*HeartbeatRequest, error) { + if len(data) < 8 { + return nil, fmt.Errorf("request too short") + } + + offset := 0 + + // GroupID (string) + groupIDLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if offset+groupIDLength > len(data) { + return nil, fmt.Errorf("invalid group ID length") + } + groupID := string(data[offset : offset+groupIDLength]) + offset += groupIDLength + + // Generation ID (4 bytes) + if offset+4 > len(data) { + return nil, fmt.Errorf("missing generation ID") + } + generationID := int32(binary.BigEndian.Uint32(data[offset:])) + offset += 4 + + // MemberID (string) + if offset+2 > len(data) { + return nil, fmt.Errorf("missing member ID length") + } + memberIDLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if offset+memberIDLength > len(data) { + return nil, fmt.Errorf("invalid member ID length") + } + memberID := string(data[offset : offset+memberIDLength]) + offset += memberIDLength + + return &HeartbeatRequest{ + GroupID: groupID, + GenerationID: generationID, + MemberID: memberID, + GroupInstanceID: "", // Simplified - would parse from remaining data + }, nil +} + +func (h *Handler) parseLeaveGroupRequest(data []byte) (*LeaveGroupRequest, error) { + if len(data) < 4 { + return nil, fmt.Errorf("request too short") + } + + offset := 0 + + // GroupID (string) + groupIDLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if offset+groupIDLength > len(data) { + return nil, fmt.Errorf("invalid group ID length") + } + groupID := string(data[offset : offset+groupIDLength]) + offset += groupIDLength + + // MemberID (string) + if offset+2 > len(data) { + return nil, fmt.Errorf("missing member ID length") + } + memberIDLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if offset+memberIDLength > len(data) { + return nil, fmt.Errorf("invalid member ID length") + } + memberID := string(data[offset : offset+memberIDLength]) + offset += memberIDLength + + return &LeaveGroupRequest{ + GroupID: groupID, + MemberID: memberID, + GroupInstanceID: "", // Simplified - would parse from remaining data + Members: []LeaveGroupMember{}, // Would parse members array for batch operations + }, nil +} + +func (h *Handler) buildHeartbeatResponse(response HeartbeatResponse) []byte { + result := make([]byte, 0, 12) + + // Correlation ID (4 bytes) + correlationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(correlationIDBytes, response.CorrelationID) + result = append(result, correlationIDBytes...) + + // Error code (2 bytes) + errorCodeBytes := make([]byte, 2) + binary.BigEndian.PutUint16(errorCodeBytes, uint16(response.ErrorCode)) + result = append(result, errorCodeBytes...) + + // Throttle time (4 bytes, 0 = no throttling) + result = append(result, 0, 0, 0, 0) + + return result +} + +func (h *Handler) buildLeaveGroupResponse(response LeaveGroupResponse) []byte { + estimatedSize := 16 + for _, member := range response.Members { + estimatedSize += len(member.MemberID) + len(member.GroupInstanceID) + 8 + } + + result := make([]byte, 0, estimatedSize) + + // Correlation ID (4 bytes) + correlationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(correlationIDBytes, response.CorrelationID) + result = append(result, correlationIDBytes...) + + // Error code (2 bytes) + errorCodeBytes := make([]byte, 2) + binary.BigEndian.PutUint16(errorCodeBytes, uint16(response.ErrorCode)) + result = append(result, errorCodeBytes...) + + // Members array length (4 bytes) + membersLengthBytes := make([]byte, 4) + binary.BigEndian.PutUint32(membersLengthBytes, uint32(len(response.Members))) + result = append(result, membersLengthBytes...) + + // Members + for _, member := range response.Members { + // Member ID length (2 bytes) + memberIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(memberIDLength, uint16(len(member.MemberID))) + result = append(result, memberIDLength...) + + // Member ID + result = append(result, []byte(member.MemberID)...) + + // Group instance ID length (2 bytes) + instanceIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(instanceIDLength, uint16(len(member.GroupInstanceID))) + result = append(result, instanceIDLength...) + + // Group instance ID + if len(member.GroupInstanceID) > 0 { + result = append(result, []byte(member.GroupInstanceID)...) + } + + // Error code (2 bytes) + memberErrorBytes := make([]byte, 2) + binary.BigEndian.PutUint16(memberErrorBytes, uint16(member.ErrorCode)) + result = append(result, memberErrorBytes...) + } + + // Throttle time (4 bytes, 0 = no throttling) + result = append(result, 0, 0, 0, 0) + + return result +} + +func (h *Handler) buildHeartbeatErrorResponse(correlationID uint32, errorCode int16) []byte { + response := HeartbeatResponse{ + CorrelationID: correlationID, + ErrorCode: errorCode, + } + + return h.buildHeartbeatResponse(response) +} + +func (h *Handler) buildLeaveGroupErrorResponse(correlationID uint32, errorCode int16) []byte { + response := LeaveGroupResponse{ + CorrelationID: correlationID, + ErrorCode: errorCode, + Members: []LeaveGroupMemberResponse{}, + } + + return h.buildLeaveGroupResponse(response) +} + +func (h *Handler) updateGroupSubscriptionFromMembers(group *consumer.ConsumerGroup) { + // Update group's subscribed topics from remaining members + group.SubscribedTopics = make(map[string]bool) + for _, member := range group.Members { + for _, topic := range member.Subscription { + group.SubscribedTopics[topic] = true + } + } +} diff --git a/weed/mq/kafka/protocol/consumer_coordination_test.go b/weed/mq/kafka/protocol/consumer_coordination_test.go new file mode 100644 index 000000000..6bd6f803f --- /dev/null +++ b/weed/mq/kafka/protocol/consumer_coordination_test.go @@ -0,0 +1,559 @@ +package protocol + +import ( + "encoding/binary" + "net" + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/mq/kafka/consumer" +) + +func TestHandler_handleHeartbeat(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group with a stable member + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStateStable, + LastHeartbeat: time.Now().Add(-5 * time.Second), // 5 seconds ago + } + group.Mu.Unlock() + + // Create a basic heartbeat request + requestBody := createHeartbeatRequestBody("test-group", 1, "member1") + + correlationID := uint32(123) + response, err := h.handleHeartbeat(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleHeartbeat failed: %v", err) + } + + if len(response) < 8 { + t.Fatalf("response too short: %d bytes", len(response)) + } + + // Check correlation ID in response + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("expected correlation ID %d, got %d", correlationID, respCorrelationID) + } + + // Check error code (should be ErrorCodeNone for successful heartbeat) + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeNone { + t.Errorf("expected error code %d, got %d", ErrorCodeNone, errorCode) + } + + // Verify heartbeat timestamp was updated + group.Mu.RLock() + member := group.Members["member1"] + if time.Since(member.LastHeartbeat) > 1*time.Second { + t.Error("heartbeat timestamp was not updated") + } + group.Mu.RUnlock() +} + +func TestHandler_handleHeartbeat_RebalanceInProgress(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group in rebalancing state + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStatePreparingRebalance // Rebalancing + group.Generation = 1 + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStatePending, + LastHeartbeat: time.Now().Add(-5 * time.Second), + } + group.Mu.Unlock() + + requestBody := createHeartbeatRequestBody("test-group", 1, "member1") + + correlationID := uint32(124) + response, err := h.handleHeartbeat(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleHeartbeat failed: %v", err) + } + + if len(response) < 8 { + t.Fatalf("response too short: %d bytes", len(response)) + } + + // Should return ErrorCodeRebalanceInProgress + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeRebalanceInProgress { + t.Errorf("expected error code %d (rebalance in progress), got %d", ErrorCodeRebalanceInProgress, errorCode) + } +} + +func TestHandler_handleHeartbeat_WrongGeneration(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group with generation 2 + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 2 + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStateStable, + LastHeartbeat: time.Now().Add(-5 * time.Second), + } + group.Mu.Unlock() + + // Send heartbeat with wrong generation (1 instead of 2) + requestBody := createHeartbeatRequestBody("test-group", 1, "member1") + + correlationID := uint32(125) + response, err := h.handleHeartbeat(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleHeartbeat failed: %v", err) + } + + // Should return ErrorCodeIllegalGeneration + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeIllegalGeneration { + t.Errorf("expected error code %d (illegal generation), got %d", ErrorCodeIllegalGeneration, errorCode) + } +} + +func TestHandler_handleHeartbeat_UnknownMember(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group without the requested member + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + // No members in group + group.Mu.Unlock() + + requestBody := createHeartbeatRequestBody("test-group", 1, "unknown-member") + + correlationID := uint32(126) + response, err := h.handleHeartbeat(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleHeartbeat failed: %v", err) + } + + // Should return ErrorCodeUnknownMemberID + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeUnknownMemberID { + t.Errorf("expected error code %d (unknown member), got %d", ErrorCodeUnknownMemberID, errorCode) + } +} + +func TestHandler_handleLeaveGroup(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group with multiple members + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + group.Leader = "member1" + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStateStable, + Subscription: []string{"topic1"}, + } + group.Members["member2"] = &consumer.GroupMember{ + ID: "member2", + State: consumer.MemberStateStable, + Subscription: []string{"topic1", "topic2"}, + } + group.SubscribedTopics = map[string]bool{ + "topic1": true, + "topic2": true, + } + group.Mu.Unlock() + + // Create a leave group request + requestBody := createLeaveGroupRequestBody("test-group", "member2") + + correlationID := uint32(127) + response, err := h.handleLeaveGroup(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleLeaveGroup failed: %v", err) + } + + if len(response) < 8 { + t.Fatalf("response too short: %d bytes", len(response)) + } + + // Check correlation ID in response + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("expected correlation ID %d, got %d", correlationID, respCorrelationID) + } + + // Check error code (should be ErrorCodeNone for successful leave) + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeNone { + t.Errorf("expected error code %d, got %d", ErrorCodeNone, errorCode) + } + + // Verify member was removed and group state updated + group.Mu.RLock() + if _, exists := group.Members["member2"]; exists { + t.Error("member2 should have been removed from group") + } + + if len(group.Members) != 1 { + t.Errorf("expected 1 remaining member, got %d", len(group.Members)) + } + + // Group should be in rebalancing state + if group.State != consumer.GroupStatePreparingRebalance { + t.Errorf("expected group state PreparingRebalance, got %s", group.State) + } + + // Generation should be incremented + if group.Generation != 2 { + t.Errorf("expected generation 2, got %d", group.Generation) + } + + // Subscribed topics should be updated (only topic1 remains) + if len(group.SubscribedTopics) != 1 || !group.SubscribedTopics["topic1"] { + t.Error("group subscribed topics were not updated correctly") + } + + if group.SubscribedTopics["topic2"] { + t.Error("topic2 should have been removed from subscribed topics") + } + group.Mu.RUnlock() +} + +func TestHandler_handleLeaveGroup_LastMember(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group with only one member + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + group.Leader = "member1" + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStateStable, + Subscription: []string{"topic1"}, + } + group.Mu.Unlock() + + requestBody := createLeaveGroupRequestBody("test-group", "member1") + + correlationID := uint32(128) + response, err := h.handleLeaveGroup(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleLeaveGroup failed: %v", err) + } + + // Check error code + errorCode := int16(binary.BigEndian.Uint16(response[4:6])) + if errorCode != ErrorCodeNone { + t.Errorf("expected error code %d, got %d", ErrorCodeNone, errorCode) + } + + // Verify group became empty + group.Mu.RLock() + if len(group.Members) != 0 { + t.Errorf("expected 0 members, got %d", len(group.Members)) + } + + if group.State != consumer.GroupStateEmpty { + t.Errorf("expected group state Empty, got %s", group.State) + } + + if group.Leader != "" { + t.Errorf("expected empty leader, got %s", group.Leader) + } + + if group.Generation != 2 { + t.Errorf("expected generation 2, got %d", group.Generation) + } + group.Mu.RUnlock() +} + +func TestHandler_handleLeaveGroup_LeaderLeaves(t *testing.T) { + h := NewHandler() + defer h.Close() + + // Create a consumer group where the leader is leaving + group := h.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + group.Leader = "leader-member" + group.Members["leader-member"] = &consumer.GroupMember{ + ID: "leader-member", + State: consumer.MemberStateStable, + } + group.Members["other-member"] = &consumer.GroupMember{ + ID: "other-member", + State: consumer.MemberStateStable, + } + group.Mu.Unlock() + + requestBody := createLeaveGroupRequestBody("test-group", "leader-member") + + correlationID := uint32(129) + _, err := h.handleLeaveGroup(correlationID, requestBody) + + if err != nil { + t.Fatalf("handleLeaveGroup failed: %v", err) + } + + // Verify leader was changed + group.Mu.RLock() + if group.Leader == "leader-member" { + t.Error("leader should have been changed after leader left") + } + + if group.Leader != "other-member" { + t.Errorf("expected new leader to be 'other-member', got '%s'", group.Leader) + } + + if len(group.Members) != 1 { + t.Errorf("expected 1 remaining member, got %d", len(group.Members)) + } + group.Mu.RUnlock() +} + +func TestHandler_parseHeartbeatRequest(t *testing.T) { + h := NewHandler() + defer h.Close() + + requestBody := createHeartbeatRequestBody("test-group", 1, "member1") + + request, err := h.parseHeartbeatRequest(requestBody) + if err != nil { + t.Fatalf("parseHeartbeatRequest failed: %v", err) + } + + if request.GroupID != "test-group" { + t.Errorf("expected group ID 'test-group', got '%s'", request.GroupID) + } + + if request.GenerationID != 1 { + t.Errorf("expected generation ID 1, got %d", request.GenerationID) + } + + if request.MemberID != "member1" { + t.Errorf("expected member ID 'member1', got '%s'", request.MemberID) + } +} + +func TestHandler_parseLeaveGroupRequest(t *testing.T) { + h := NewHandler() + defer h.Close() + + requestBody := createLeaveGroupRequestBody("test-group", "member1") + + request, err := h.parseLeaveGroupRequest(requestBody) + if err != nil { + t.Fatalf("parseLeaveGroupRequest failed: %v", err) + } + + if request.GroupID != "test-group" { + t.Errorf("expected group ID 'test-group', got '%s'", request.GroupID) + } + + if request.MemberID != "member1" { + t.Errorf("expected member ID 'member1', got '%s'", request.MemberID) + } +} + +func TestHandler_buildHeartbeatResponse(t *testing.T) { + h := NewHandler() + defer h.Close() + + response := HeartbeatResponse{ + CorrelationID: 123, + ErrorCode: ErrorCodeRebalanceInProgress, + } + + responseBytes := h.buildHeartbeatResponse(response) + + if len(responseBytes) != 10 { // 4 + 2 + 4 bytes + t.Fatalf("expected response length 10, got %d", len(responseBytes)) + } + + // Check correlation ID + correlationID := binary.BigEndian.Uint32(responseBytes[0:4]) + if correlationID != 123 { + t.Errorf("expected correlation ID 123, got %d", correlationID) + } + + // Check error code + errorCode := int16(binary.BigEndian.Uint16(responseBytes[4:6])) + if errorCode != ErrorCodeRebalanceInProgress { + t.Errorf("expected error code %d, got %d", ErrorCodeRebalanceInProgress, errorCode) + } +} + +func TestHandler_buildLeaveGroupResponse(t *testing.T) { + h := NewHandler() + defer h.Close() + + response := LeaveGroupResponse{ + CorrelationID: 124, + ErrorCode: ErrorCodeNone, + Members: []LeaveGroupMemberResponse{ + { + MemberID: "member1", + GroupInstanceID: "", + ErrorCode: ErrorCodeNone, + }, + }, + } + + responseBytes := h.buildLeaveGroupResponse(response) + + if len(responseBytes) < 16 { + t.Fatalf("response too short: %d bytes", len(responseBytes)) + } + + // Check correlation ID + correlationID := binary.BigEndian.Uint32(responseBytes[0:4]) + if correlationID != 124 { + t.Errorf("expected correlation ID 124, got %d", correlationID) + } + + // Check error code + errorCode := int16(binary.BigEndian.Uint16(responseBytes[4:6])) + if errorCode != ErrorCodeNone { + t.Errorf("expected error code %d, got %d", ErrorCodeNone, errorCode) + } +} + +func TestHandler_HeartbeatLeaveGroup_EndToEnd(t *testing.T) { + // Create two handlers connected via pipe to simulate client-server + server := NewHandler() + defer server.Close() + + client := NewHandler() + defer client.Close() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + // Setup consumer group on server + group := server.groupCoordinator.GetOrCreateGroup("test-group") + group.Mu.Lock() + group.State = consumer.GroupStateStable + group.Generation = 1 + group.Leader = "member1" + group.Members["member1"] = &consumer.GroupMember{ + ID: "member1", + State: consumer.MemberStateStable, + LastHeartbeat: time.Now().Add(-10 * time.Second), + } + group.Mu.Unlock() + + // Test heartbeat + heartbeatRequestBody := createHeartbeatRequestBody("test-group", 1, "member1") + heartbeatResponse, err := server.handleHeartbeat(456, heartbeatRequestBody) + if err != nil { + t.Fatalf("heartbeat failed: %v", err) + } + + if len(heartbeatResponse) < 8 { + t.Fatalf("heartbeat response too short: %d bytes", len(heartbeatResponse)) + } + + // Verify heartbeat was processed + group.Mu.RLock() + member := group.Members["member1"] + if time.Since(member.LastHeartbeat) > 1*time.Second { + t.Error("heartbeat timestamp was not updated") + } + group.Mu.RUnlock() + + // Test leave group + leaveGroupRequestBody := createLeaveGroupRequestBody("test-group", "member1") + leaveGroupResponse, err := server.handleLeaveGroup(457, leaveGroupRequestBody) + if err != nil { + t.Fatalf("leave group failed: %v", err) + } + + if len(leaveGroupResponse) < 8 { + t.Fatalf("leave group response too short: %d bytes", len(leaveGroupResponse)) + } + + // Verify member left and group became empty + group.Mu.RLock() + if len(group.Members) != 0 { + t.Errorf("expected 0 members after leave, got %d", len(group.Members)) + } + + if group.State != consumer.GroupStateEmpty { + t.Errorf("expected group state Empty, got %s", group.State) + } + group.Mu.RUnlock() +} + +// Helper functions for creating test request bodies + +func createHeartbeatRequestBody(groupID string, generationID int32, memberID string) []byte { + body := make([]byte, 0, 64) + + // Group ID (string) + groupIDBytes := []byte(groupID) + groupIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(groupIDLength, uint16(len(groupIDBytes))) + body = append(body, groupIDLength...) + body = append(body, groupIDBytes...) + + // Generation ID (4 bytes) + generationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(generationIDBytes, uint32(generationID)) + body = append(body, generationIDBytes...) + + // Member ID (string) + memberIDBytes := []byte(memberID) + memberIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(memberIDLength, uint16(len(memberIDBytes))) + body = append(body, memberIDLength...) + body = append(body, memberIDBytes...) + + return body +} + +func createLeaveGroupRequestBody(groupID string, memberID string) []byte { + body := make([]byte, 0, 32) + + // Group ID (string) + groupIDBytes := []byte(groupID) + groupIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(groupIDLength, uint16(len(groupIDBytes))) + body = append(body, groupIDLength...) + body = append(body, groupIDBytes...) + + // Member ID (string) + memberIDBytes := []byte(memberID) + memberIDLength := make([]byte, 2) + binary.BigEndian.PutUint16(memberIDLength, uint16(len(memberIDBytes))) + body = append(body, memberIDLength...) + body = append(body, memberIDBytes...) + + return body +} diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go index 27785f6d2..7aad2eba8 100644 --- a/weed/mq/kafka/protocol/handler.go +++ b/weed/mq/kafka/protocol/handler.go @@ -187,6 +187,10 @@ func (h *Handler) HandleConn(conn net.Conn) error { response, err = h.handleOffsetCommit(correlationID, messageBuf[8:]) // skip header case 9: // OffsetFetch response, err = h.handleOffsetFetch(correlationID, messageBuf[8:]) // skip header + case 12: // Heartbeat + response, err = h.handleHeartbeat(correlationID, messageBuf[8:]) // skip header + case 13: // LeaveGroup + response, err = h.handleLeaveGroup(correlationID, messageBuf[8:]) // skip header default: err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion) } @@ -227,7 +231,7 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) { response = append(response, 0, 0) // Number of API keys (compact array format in newer versions, but using basic format for simplicity) - response = append(response, 0, 0, 0, 11) // 11 API keys + response = append(response, 0, 0, 0, 13) // 13 API keys // API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2) response = append(response, 0, 18) // API key 18 @@ -284,6 +288,16 @@ func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) { response = append(response, 0, 0) // min version 0 response = append(response, 0, 8) // max version 8 + // API Key 12 (Heartbeat): api_key(2) + min_version(2) + max_version(2) + response = append(response, 0, 12) // API key 12 + response = append(response, 0, 0) // min version 0 + response = append(response, 0, 4) // max version 4 + + // API Key 13 (LeaveGroup): api_key(2) + min_version(2) + max_version(2) + response = append(response, 0, 13) // API key 13 + response = append(response, 0, 0) // min version 0 + response = append(response, 0, 4) // max version 4 + // Throttle time (4 bytes, 0 = no throttling) response = append(response, 0, 0, 0, 0) diff --git a/weed/mq/kafka/protocol/handler_test.go b/weed/mq/kafka/protocol/handler_test.go index 2362252f1..c699e593a 100644 --- a/weed/mq/kafka/protocol/handler_test.go +++ b/weed/mq/kafka/protocol/handler_test.go @@ -92,8 +92,8 @@ func TestHandler_ApiVersions(t *testing.T) { // Check number of API keys numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10]) - if numAPIKeys != 11 { - t.Errorf("expected 11 API keys, got: %d", numAPIKeys) + if numAPIKeys != 13 { + t.Errorf("expected 13 API keys, got: %d", numAPIKeys) } // Check API key details: api_key(2) + min_version(2) + max_version(2) @@ -229,7 +229,7 @@ func TestHandler_handleApiVersions(t *testing.T) { t.Fatalf("handleApiVersions: %v", err) } - if len(response) < 78 { // minimum expected size (now has 11 API keys) + if len(response) < 90 { // minimum expected size (now has 13 API keys) t.Fatalf("response too short: %d bytes", len(response)) } @@ -247,8 +247,8 @@ func TestHandler_handleApiVersions(t *testing.T) { // Check number of API keys numAPIKeys := binary.BigEndian.Uint32(response[6:10]) - if numAPIKeys != 11 { - t.Errorf("expected 11 API keys, got: %d", numAPIKeys) + if numAPIKeys != 13 { + t.Errorf("expected 13 API keys, got: %d", numAPIKeys) } // Check first API key (ApiVersions)