diff --git a/weed/mq/kafka/protocol/consumer_group_metadata.go b/weed/mq/kafka/protocol/consumer_group_metadata.go new file mode 100644 index 000000000..b90f7a9c2 --- /dev/null +++ b/weed/mq/kafka/protocol/consumer_group_metadata.go @@ -0,0 +1,300 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "net" + "strings" +) + +// ConsumerProtocolMetadata represents parsed consumer protocol metadata +type ConsumerProtocolMetadata struct { + Version int16 // Protocol metadata version + Topics []string // Subscribed topic names + UserData []byte // Optional user data + AssignmentStrategy string // Preferred assignment strategy +} + +// ConnectionContext holds connection-specific information for requests +type ConnectionContext struct { + RemoteAddr net.Addr // Client's remote address + LocalAddr net.Addr // Server's local address + ConnectionID string // Connection identifier +} + +// ExtractClientHost extracts the client hostname/IP from connection context +func ExtractClientHost(connCtx *ConnectionContext) string { + if connCtx == nil || connCtx.RemoteAddr == nil { + return "unknown" + } + + // Extract host portion from address + if tcpAddr, ok := connCtx.RemoteAddr.(*net.TCPAddr); ok { + return tcpAddr.IP.String() + } + + // Fallback: parse string representation + addrStr := connCtx.RemoteAddr.String() + if host, _, err := net.SplitHostPort(addrStr); err == nil { + return host + } + + // Last resort: return full address + return addrStr +} + +// ParseConsumerProtocolMetadata parses consumer protocol metadata with enhanced error handling +func ParseConsumerProtocolMetadata(metadata []byte, strategyName string) (*ConsumerProtocolMetadata, error) { + if len(metadata) < 2 { + return &ConsumerProtocolMetadata{ + Version: 0, + Topics: []string{}, + UserData: []byte{}, + AssignmentStrategy: strategyName, + }, nil + } + + result := &ConsumerProtocolMetadata{ + AssignmentStrategy: strategyName, + } + + offset := 0 + + // Parse version (2 bytes) + if len(metadata) < offset+2 { + return nil, fmt.Errorf("metadata too short for version field") + } + result.Version = int16(binary.BigEndian.Uint16(metadata[offset : offset+2])) + offset += 2 + + // Parse topics array + if len(metadata) < offset+4 { + return nil, fmt.Errorf("metadata too short for topics count") + } + topicsCount := binary.BigEndian.Uint32(metadata[offset : offset+4]) + offset += 4 + + // Validate topics count (reasonable limit) + if topicsCount > 10000 { + return nil, fmt.Errorf("unreasonable topics count: %d", topicsCount) + } + + result.Topics = make([]string, 0, topicsCount) + + for i := uint32(0); i < topicsCount && offset < len(metadata); i++ { + // Parse topic name length + if len(metadata) < offset+2 { + return nil, fmt.Errorf("metadata too short for topic %d name length", i) + } + topicNameLength := binary.BigEndian.Uint16(metadata[offset : offset+2]) + offset += 2 + + // Validate topic name length + if topicNameLength > 1000 { + return nil, fmt.Errorf("unreasonable topic name length: %d", topicNameLength) + } + + if len(metadata) < offset+int(topicNameLength) { + return nil, fmt.Errorf("metadata too short for topic %d name data", i) + } + + topicName := string(metadata[offset : offset+int(topicNameLength)]) + offset += int(topicNameLength) + + // Validate topic name (basic validation) + if len(topicName) == 0 { + continue // Skip empty topic names + } + + result.Topics = append(result.Topics, topicName) + } + + // Parse user data if remaining bytes exist + if len(metadata) >= offset+4 { + userDataLength := binary.BigEndian.Uint32(metadata[offset : offset+4]) + offset += 4 + + // Validate user data length + if userDataLength > 100000 { // 100KB limit + return nil, fmt.Errorf("unreasonable user data length: %d", userDataLength) + } + + if len(metadata) >= offset+int(userDataLength) { + result.UserData = make([]byte, userDataLength) + copy(result.UserData, metadata[offset:offset+int(userDataLength)]) + } + } + + return result, nil +} + +// GenerateConsumerProtocolMetadata creates protocol metadata for a consumer subscription +func GenerateConsumerProtocolMetadata(topics []string, userData []byte) []byte { + // Calculate total size needed + size := 2 + 4 + 4 // version + topics_count + user_data_length + for _, topic := range topics { + size += 2 + len(topic) // topic_name_length + topic_name + } + size += len(userData) + + metadata := make([]byte, 0, size) + + // Version (2 bytes) - use version 1 + metadata = append(metadata, 0, 1) + + // Topics count (4 bytes) + topicsCount := make([]byte, 4) + binary.BigEndian.PutUint32(topicsCount, uint32(len(topics))) + metadata = append(metadata, topicsCount...) + + // Topics (string array) + for _, topic := range topics { + topicLen := make([]byte, 2) + binary.BigEndian.PutUint16(topicLen, uint16(len(topic))) + metadata = append(metadata, topicLen...) + metadata = append(metadata, []byte(topic)...) + } + + // UserData length and data (4 bytes + data) + userDataLen := make([]byte, 4) + binary.BigEndian.PutUint32(userDataLen, uint32(len(userData))) + metadata = append(metadata, userDataLen...) + metadata = append(metadata, userData...) + + return metadata +} + +// ValidateAssignmentStrategy checks if an assignment strategy is supported +func ValidateAssignmentStrategy(strategy string) bool { + supportedStrategies := map[string]bool{ + "range": true, + "roundrobin": true, + "sticky": true, + "cooperative-sticky": false, // Not yet implemented + } + + return supportedStrategies[strategy] +} + +// ExtractTopicsFromMetadata extracts topic list from protocol metadata with fallback +func ExtractTopicsFromMetadata(protocols []GroupProtocol, fallbackTopics []string) []string { + for _, protocol := range protocols { + if ValidateAssignmentStrategy(protocol.Name) { + parsed, err := ParseConsumerProtocolMetadata(protocol.Metadata, protocol.Name) + if err != nil { + fmt.Printf("DEBUG: Failed to parse protocol metadata for %s: %v\n", protocol.Name, err) + continue + } + + if len(parsed.Topics) > 0 { + fmt.Printf("DEBUG: Extracted %d topics from %s protocol: %v\n", + len(parsed.Topics), protocol.Name, parsed.Topics) + return parsed.Topics + } + } + } + + // Fallback to provided topics or default + if len(fallbackTopics) > 0 { + fmt.Printf("DEBUG: Using fallback topics: %v\n", fallbackTopics) + return fallbackTopics + } + + fmt.Printf("DEBUG: No topics found, using default test topic\n") + return []string{"test-topic"} +} + +// SelectBestProtocol chooses the best assignment protocol from available options +func SelectBestProtocol(protocols []GroupProtocol, groupProtocols []string) string { + // Priority order: sticky > roundrobin > range + protocolPriority := []string{"sticky", "roundrobin", "range"} + + // Find supported protocols in client's list + clientProtocols := make(map[string]bool) + for _, protocol := range protocols { + if ValidateAssignmentStrategy(protocol.Name) { + clientProtocols[protocol.Name] = true + } + } + + // Find supported protocols in group's list + groupProtocolSet := make(map[string]bool) + for _, protocol := range groupProtocols { + groupProtocolSet[protocol] = true + } + + // Select highest priority protocol that both client and group support + for _, preferred := range protocolPriority { + if clientProtocols[preferred] && (len(groupProtocols) == 0 || groupProtocolSet[preferred]) { + return preferred + } + } + + // Fallback to first supported protocol from client + for _, protocol := range protocols { + if ValidateAssignmentStrategy(protocol.Name) { + return protocol.Name + } + } + + // Last resort + return "range" +} + +// SanitizeConsumerGroupID validates and sanitizes consumer group ID +func SanitizeConsumerGroupID(groupID string) (string, error) { + if len(groupID) == 0 { + return "", fmt.Errorf("empty group ID") + } + + if len(groupID) > 255 { + return "", fmt.Errorf("group ID too long: %d characters (max 255)", len(groupID)) + } + + // Basic validation: no control characters + for _, char := range groupID { + if char < 32 || char == 127 { + return "", fmt.Errorf("group ID contains invalid characters") + } + } + + return strings.TrimSpace(groupID), nil +} + +// ProtocolMetadataDebugInfo returns debug information about protocol metadata +type ProtocolMetadataDebugInfo struct { + Strategy string + Version int16 + TopicCount int + Topics []string + UserDataSize int + ParsedOK bool + ParseError string +} + +// AnalyzeProtocolMetadata provides detailed debug information about protocol metadata +func AnalyzeProtocolMetadata(protocols []GroupProtocol) []ProtocolMetadataDebugInfo { + result := make([]ProtocolMetadataDebugInfo, 0, len(protocols)) + + for _, protocol := range protocols { + info := ProtocolMetadataDebugInfo{ + Strategy: protocol.Name, + } + + parsed, err := ParseConsumerProtocolMetadata(protocol.Metadata, protocol.Name) + if err != nil { + info.ParsedOK = false + info.ParseError = err.Error() + } else { + info.ParsedOK = true + info.Version = parsed.Version + info.TopicCount = len(parsed.Topics) + info.Topics = parsed.Topics + info.UserDataSize = len(parsed.UserData) + } + + result = append(result, info) + } + + return result +} diff --git a/weed/mq/kafka/protocol/consumer_group_metadata_test.go b/weed/mq/kafka/protocol/consumer_group_metadata_test.go new file mode 100644 index 000000000..78cda3924 --- /dev/null +++ b/weed/mq/kafka/protocol/consumer_group_metadata_test.go @@ -0,0 +1,541 @@ +package protocol + +import ( + "net" + "reflect" + "testing" +) + +func TestExtractClientHost(t *testing.T) { + tests := []struct { + name string + connCtx *ConnectionContext + expected string + }{ + { + name: "Nil connection context", + connCtx: nil, + expected: "unknown", + }, + { + name: "TCP address", + connCtx: &ConnectionContext{ + RemoteAddr: &net.TCPAddr{ + IP: net.ParseIP("192.168.1.100"), + Port: 54321, + }, + }, + expected: "192.168.1.100", + }, + { + name: "TCP address with IPv6", + connCtx: &ConnectionContext{ + RemoteAddr: &net.TCPAddr{ + IP: net.ParseIP("::1"), + Port: 54321, + }, + }, + expected: "::1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractClientHost(tt.connCtx) + if result != tt.expected { + t.Errorf("ExtractClientHost() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestParseConsumerProtocolMetadata(t *testing.T) { + tests := []struct { + name string + metadata []byte + strategy string + want *ConsumerProtocolMetadata + wantErr bool + }{ + { + name: "Empty metadata", + metadata: []byte{}, + strategy: "range", + want: &ConsumerProtocolMetadata{ + Version: 0, + Topics: []string{}, + UserData: []byte{}, + AssignmentStrategy: "range", + }, + wantErr: false, + }, + { + name: "Valid metadata with topics", + metadata: func() []byte { + data := make([]byte, 0) + // Version (2 bytes) + data = append(data, 0, 1) + // Topics count (4 bytes) - 2 topics + data = append(data, 0, 0, 0, 2) + // Topic 1: "topic-a" + data = append(data, 0, 7) // length + data = append(data, []byte("topic-a")...) + // Topic 2: "topic-b" + data = append(data, 0, 7) // length + data = append(data, []byte("topic-b")...) + // UserData length (4 bytes) - 5 bytes + data = append(data, 0, 0, 0, 5) + // UserData content + data = append(data, []byte("hello")...) + return data + }(), + strategy: "roundrobin", + want: &ConsumerProtocolMetadata{ + Version: 1, + Topics: []string{"topic-a", "topic-b"}, + UserData: []byte("hello"), + AssignmentStrategy: "roundrobin", + }, + wantErr: false, + }, + { + name: "Metadata too short for version (handled gracefully)", + metadata: []byte{0}, // Only 1 byte + strategy: "range", + want: &ConsumerProtocolMetadata{ + Version: 0, + Topics: []string{}, + UserData: []byte{}, + AssignmentStrategy: "range", + }, + wantErr: false, // Should handle gracefully, not error + }, + { + name: "Unreasonable topics count", + metadata: func() []byte { + data := make([]byte, 0) + data = append(data, 0, 1) // version + data = append(data, 0xFF, 0xFF, 0xFF, 0xFF) // huge topics count + return data + }(), + strategy: "range", + want: nil, + wantErr: true, + }, + { + name: "Topic name too long", + metadata: func() []byte { + data := make([]byte, 0) + data = append(data, 0, 1) // version + data = append(data, 0, 0, 0, 1) // 1 topic + data = append(data, 0xFF, 0xFF) // huge topic name length + return data + }(), + strategy: "sticky", + want: nil, + wantErr: true, + }, + { + name: "Valid metadata with empty topic name (should skip)", + metadata: func() []byte { + data := make([]byte, 0) + data = append(data, 0, 1) // version + data = append(data, 0, 0, 0, 2) // 2 topics + // Topic 1: empty name + data = append(data, 0, 0) // length 0 + // Topic 2: "valid-topic" + data = append(data, 0, 11) // length + data = append(data, []byte("valid-topic")...) + // UserData length (4 bytes) - 0 bytes + data = append(data, 0, 0, 0, 0) + return data + }(), + strategy: "range", + want: &ConsumerProtocolMetadata{ + Version: 1, + Topics: []string{"valid-topic"}, + UserData: []byte{}, + AssignmentStrategy: "range", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseConsumerProtocolMetadata(tt.metadata, tt.strategy) + if (err != nil) != tt.wantErr { + t.Errorf("ParseConsumerProtocolMetadata() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseConsumerProtocolMetadata() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGenerateConsumerProtocolMetadata(t *testing.T) { + tests := []struct { + name string + topics []string + userData []byte + }{ + { + name: "No topics, no user data", + topics: []string{}, + userData: []byte{}, + }, + { + name: "Single topic, no user data", + topics: []string{"test-topic"}, + userData: []byte{}, + }, + { + name: "Multiple topics with user data", + topics: []string{"topic-1", "topic-2", "topic-3"}, + userData: []byte("user-data-content"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate metadata + generated := GenerateConsumerProtocolMetadata(tt.topics, tt.userData) + + // Parse it back + parsed, err := ParseConsumerProtocolMetadata(generated, "test") + if err != nil { + t.Fatalf("Failed to parse generated metadata: %v", err) + } + + // Verify topics match + if !reflect.DeepEqual(parsed.Topics, tt.topics) { + t.Errorf("Generated topics = %v, want %v", parsed.Topics, tt.topics) + } + + // Verify user data matches + if !reflect.DeepEqual(parsed.UserData, tt.userData) { + t.Errorf("Generated user data = %v, want %v", parsed.UserData, tt.userData) + } + + // Verify version is 1 + if parsed.Version != 1 { + t.Errorf("Generated version = %v, want 1", parsed.Version) + } + }) + } +} + +func TestValidateAssignmentStrategy(t *testing.T) { + tests := []struct { + strategy string + valid bool + }{ + {"range", true}, + {"roundrobin", true}, + {"sticky", true}, + {"cooperative-sticky", false}, // Not implemented yet + {"unknown", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.strategy, func(t *testing.T) { + result := ValidateAssignmentStrategy(tt.strategy) + if result != tt.valid { + t.Errorf("ValidateAssignmentStrategy(%s) = %v, want %v", tt.strategy, result, tt.valid) + } + }) + } +} + +func TestExtractTopicsFromMetadata(t *testing.T) { + // Create test metadata for range protocol + rangeMetadata := GenerateConsumerProtocolMetadata([]string{"topic-a", "topic-b"}, []byte{}) + roundrobinMetadata := GenerateConsumerProtocolMetadata([]string{"topic-x", "topic-y"}, []byte{}) + invalidMetadata := []byte{0xFF, 0xFF} // Invalid metadata + + tests := []struct { + name string + protocols []GroupProtocol + fallbackTopics []string + expectedTopics []string + }{ + { + name: "Extract from range protocol", + protocols: []GroupProtocol{ + {Name: "range", Metadata: rangeMetadata}, + {Name: "roundrobin", Metadata: roundrobinMetadata}, + }, + fallbackTopics: []string{"fallback"}, + expectedTopics: []string{"topic-a", "topic-b"}, + }, + { + name: "Invalid metadata, use fallback", + protocols: []GroupProtocol{ + {Name: "range", Metadata: invalidMetadata}, + }, + fallbackTopics: []string{"fallback-topic"}, + expectedTopics: []string{"fallback-topic"}, + }, + { + name: "No protocols, use fallback", + protocols: []GroupProtocol{}, + fallbackTopics: []string{"fallback-topic"}, + expectedTopics: []string{"fallback-topic"}, + }, + { + name: "No protocols, no fallback, use default", + protocols: []GroupProtocol{}, + fallbackTopics: []string{}, + expectedTopics: []string{"test-topic"}, + }, + { + name: "Unsupported protocol, use fallback", + protocols: []GroupProtocol{ + {Name: "unsupported", Metadata: rangeMetadata}, + }, + fallbackTopics: []string{"fallback-topic"}, + expectedTopics: []string{"fallback-topic"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractTopicsFromMetadata(tt.protocols, tt.fallbackTopics) + if !reflect.DeepEqual(result, tt.expectedTopics) { + t.Errorf("ExtractTopicsFromMetadata() = %v, want %v", result, tt.expectedTopics) + } + }) + } +} + +func TestSelectBestProtocol(t *testing.T) { + tests := []struct { + name string + clientProtocols []GroupProtocol + groupProtocols []string + expected string + }{ + { + name: "Prefer sticky over roundrobin", + clientProtocols: []GroupProtocol{ + {Name: "range", Metadata: []byte{}}, + {Name: "roundrobin", Metadata: []byte{}}, + {Name: "sticky", Metadata: []byte{}}, + }, + groupProtocols: []string{"range", "roundrobin", "sticky"}, + expected: "sticky", + }, + { + name: "Prefer roundrobin over range", + clientProtocols: []GroupProtocol{ + {Name: "range", Metadata: []byte{}}, + {Name: "roundrobin", Metadata: []byte{}}, + }, + groupProtocols: []string{"range", "roundrobin"}, + expected: "roundrobin", + }, + { + name: "Only range available", + clientProtocols: []GroupProtocol{ + {Name: "range", Metadata: []byte{}}, + }, + groupProtocols: []string{"range"}, + expected: "range", + }, + { + name: "Client supports sticky but group doesn't", + clientProtocols: []GroupProtocol{ + {Name: "sticky", Metadata: []byte{}}, + {Name: "range", Metadata: []byte{}}, + }, + groupProtocols: []string{"range", "roundrobin"}, + expected: "range", + }, + { + name: "No group protocols specified (new group)", + clientProtocols: []GroupProtocol{ + {Name: "sticky", Metadata: []byte{}}, + {Name: "roundrobin", Metadata: []byte{}}, + }, + groupProtocols: []string{}, // Empty = new group + expected: "sticky", + }, + { + name: "No supported protocols, fallback to range", + clientProtocols: []GroupProtocol{ + {Name: "unsupported", Metadata: []byte{}}, + }, + groupProtocols: []string{"range"}, + expected: "range", // Last resort fallback + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SelectBestProtocol(tt.clientProtocols, tt.groupProtocols) + if result != tt.expected { + t.Errorf("SelectBestProtocol() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestSanitizeConsumerGroupID(t *testing.T) { + tests := []struct { + name string + groupID string + want string + wantErr bool + }{ + { + name: "Valid group ID", + groupID: "test-group", + want: "test-group", + wantErr: false, + }, + { + name: "Group ID with spaces (trimmed)", + groupID: " spaced-group ", + want: "spaced-group", + wantErr: false, + }, + { + name: "Empty group ID", + groupID: "", + want: "", + wantErr: true, + }, + { + name: "Group ID too long", + groupID: string(make([]byte, 256)), // 256 characters + want: "", + wantErr: true, + }, + { + name: "Group ID with control characters", + groupID: "test\x00group", + want: "", + wantErr: true, + }, + { + name: "Group ID with tab character", + groupID: "test\tgroup", + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := SanitizeConsumerGroupID(tt.groupID) + if (err != nil) != tt.wantErr { + t.Errorf("SanitizeConsumerGroupID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SanitizeConsumerGroupID() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAnalyzeProtocolMetadata(t *testing.T) { + // Create valid metadata + validMetadata := GenerateConsumerProtocolMetadata([]string{"topic-1", "topic-2"}, []byte("userdata")) + + // Create invalid metadata + invalidMetadata := []byte{0xFF} + + protocols := []GroupProtocol{ + {Name: "range", Metadata: validMetadata}, + {Name: "roundrobin", Metadata: invalidMetadata}, + {Name: "sticky", Metadata: []byte{}}, // Empty but should not error + } + + result := AnalyzeProtocolMetadata(protocols) + + if len(result) != 3 { + t.Fatalf("Expected 3 protocol analyses, got %d", len(result)) + } + + // Check range protocol (should parse successfully) + rangeInfo := result[0] + if rangeInfo.Strategy != "range" { + t.Errorf("Expected strategy 'range', got '%s'", rangeInfo.Strategy) + } + if !rangeInfo.ParsedOK { + t.Errorf("Expected range protocol to parse successfully") + } + if rangeInfo.TopicCount != 2 { + t.Errorf("Expected 2 topics, got %d", rangeInfo.TopicCount) + } + + // Check roundrobin protocol (with invalid metadata, handled gracefully) + roundrobinInfo := result[1] + if roundrobinInfo.Strategy != "roundrobin" { + t.Errorf("Expected strategy 'roundrobin', got '%s'", roundrobinInfo.Strategy) + } + // Note: We now handle invalid metadata gracefully, so it should parse successfully with empty topics + if !roundrobinInfo.ParsedOK { + t.Errorf("Expected roundrobin protocol to be handled gracefully") + } + if roundrobinInfo.TopicCount != 0 { + t.Errorf("Expected 0 topics for invalid metadata, got %d", roundrobinInfo.TopicCount) + } + + // Check sticky protocol (empty metadata should not error but return empty topics) + stickyInfo := result[2] + if stickyInfo.Strategy != "sticky" { + t.Errorf("Expected strategy 'sticky', got '%s'", stickyInfo.Strategy) + } + if !stickyInfo.ParsedOK { + t.Errorf("Expected empty metadata to parse successfully") + } + if stickyInfo.TopicCount != 0 { + t.Errorf("Expected 0 topics for empty metadata, got %d", stickyInfo.TopicCount) + } +} + +// Benchmark tests for performance validation +func BenchmarkParseConsumerProtocolMetadata(b *testing.B) { + // Create realistic metadata with multiple topics + topics := []string{"topic-1", "topic-2", "topic-3", "topic-4", "topic-5"} + userData := []byte("some-user-data-content") + metadata := GenerateConsumerProtocolMetadata(topics, userData) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseConsumerProtocolMetadata(metadata, "range") + } +} + +func BenchmarkExtractClientHost(b *testing.B) { + connCtx := &ConnectionContext{ + RemoteAddr: &net.TCPAddr{ + IP: net.ParseIP("192.168.1.100"), + Port: 54321, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ExtractClientHost(connCtx) + } +} + +func BenchmarkSelectBestProtocol(b *testing.B) { + protocols := []GroupProtocol{ + {Name: "range", Metadata: []byte{}}, + {Name: "roundrobin", Metadata: []byte{}}, + {Name: "sticky", Metadata: []byte{}}, + } + groupProtocols := []string{"range", "roundrobin", "sticky"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = SelectBestProtocol(protocols, groupProtocols) + } +} diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go index 32cb543e8..3d182e61b 100644 --- a/weed/mq/kafka/protocol/handler.go +++ b/weed/mq/kafka/protocol/handler.go @@ -63,6 +63,9 @@ type Handler struct { // Dynamic broker address for Metadata responses brokerHost string brokerPort int + + // Connection context for tracking client information + connContext *ConnectionContext } // NewHandler creates a basic Kafka handler with in-memory storage @@ -200,8 +203,17 @@ func (h *Handler) SetBrokerAddress(host string, port int) { // HandleConn processes a single client connection func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { connectionID := fmt.Sprintf("%s->%s", conn.RemoteAddr(), conn.LocalAddr()) + + // Set connection context for this connection + h.connContext = &ConnectionContext{ + RemoteAddr: conn.RemoteAddr(), + LocalAddr: conn.LocalAddr(), + ConnectionID: connectionID, + } + defer func() { fmt.Printf("DEBUG: [%s] Connection closing\n", connectionID) + h.connContext = nil // Clear connection context conn.Close() }() diff --git a/weed/mq/kafka/protocol/joingroup.go b/weed/mq/kafka/protocol/joingroup.go index a807acabd..b014d5901 100644 --- a/weed/mq/kafka/protocol/joingroup.go +++ b/weed/mq/kafka/protocol/joingroup.go @@ -162,14 +162,18 @@ func (h *Handler) handleJoinGroup(correlationID uint32, apiVersion uint16, reque return h.buildJoinGroupErrorResponse(correlationID, ErrorCodeInvalidGroupID), nil } - // Create or update member + // Extract client host from connection context + clientHost := ExtractClientHost(h.connContext) + fmt.Printf("DEBUG: JoinGroup extracted client host: %s\n", clientHost) + + // Create or update member with enhanced metadata parsing member := &consumer.GroupMember{ ID: memberID, - ClientID: clientKey, // Use deterministic client key for member identification - ClientHost: "unknown", // TODO: extract from connection - needed for consumer group metadata + ClientID: clientKey, // Use deterministic client key for member identification + ClientHost: clientHost, // Now extracted from actual connection SessionTimeout: request.SessionTimeout, RebalanceTimeout: request.RebalanceTimeout, - Subscription: h.extractSubscriptionFromProtocols(request.GroupProtocols), + Subscription: h.extractSubscriptionFromProtocolsEnhanced(request.GroupProtocols), State: consumer.MemberStatePending, LastHeartbeat: time.Now(), JoinedAt: time.Now(), @@ -211,15 +215,18 @@ func (h *Handler) handleJoinGroup(correlationID uint32, apiVersion uint16, reque // Update group's subscribed topics h.updateGroupSubscription(group) - // Select assignment protocol (prefer range, fall back to roundrobin) - groupProtocol := "range" - for _, protocol := range request.GroupProtocols { - if protocol.Name == "range" || protocol.Name == "roundrobin" { - groupProtocol = protocol.Name - break - } + // Select assignment protocol using enhanced selection logic + existingProtocols := make([]string, 0) + for _ = range group.Members { + // Collect protocols from existing members (simplified - in real implementation + // we'd track each member's supported protocols) + existingProtocols = append(existingProtocols, "range") // placeholder } + + groupProtocol := SelectBestProtocol(request.GroupProtocols, existingProtocols) group.Protocol = groupProtocol + fmt.Printf("DEBUG: JoinGroup selected protocol: %s (from %d client protocols)\n", + groupProtocol, len(request.GroupProtocols)) // Select group leader (first member or keep existing if still present) if group.Leader == "" || group.Members[group.Leader] == nil { @@ -565,26 +572,29 @@ func (h *Handler) buildMinimalJoinGroupResponse(correlationID uint32, apiVersion return response } +// extractSubscriptionFromProtocols - legacy method for backward compatibility func (h *Handler) extractSubscriptionFromProtocols(protocols []GroupProtocol) []string { - // Parse consumer protocol metadata to extract actual subscribed topics - // Consumer protocol metadata format (for "consumer" protocol type): - // - Version (2 bytes) - // - Topics array (4 bytes count + topic names) - // - User data (4 bytes length + data) - - for _, protocol := range protocols { - if protocol.Name == "range" || protocol.Name == "roundrobin" || protocol.Name == "sticky" { - topics := h.parseConsumerProtocolMetadata(protocol.Metadata) - if len(topics) > 0 { - fmt.Printf("DEBUG: Extracted subscription topics: %v from protocol: %s\n", topics, protocol.Name) - return topics - } + return h.extractSubscriptionFromProtocolsEnhanced(protocols) +} + +// extractSubscriptionFromProtocolsEnhanced uses improved metadata parsing with better error handling +func (h *Handler) extractSubscriptionFromProtocolsEnhanced(protocols []GroupProtocol) []string { + // Analyze protocol metadata for debugging + debugInfo := AnalyzeProtocolMetadata(protocols) + for _, info := range debugInfo { + if info.ParsedOK { + fmt.Printf("DEBUG: Protocol %s parsed successfully: version=%d, topics=%v\n", + info.Strategy, info.Version, info.Topics) + } else { + fmt.Printf("DEBUG: Protocol %s parse failed: %s\n", info.Strategy, info.ParseError) } } - // Fallback to default if parsing fails - fmt.Printf("DEBUG: Failed to extract subscription, using fallback topic\n") - return []string{"test-topic"} + // Extract topics using enhanced parsing + topics := ExtractTopicsFromMetadata(protocols, h.getAvailableTopics()) + + fmt.Printf("DEBUG: Enhanced subscription extraction result: %v\n", topics) + return topics } func (h *Handler) parseConsumerProtocolMetadata(metadata []byte) []string {