Browse Source
mq(kafka): implement ApiVersions protocol handler with manual binary encoding and comprehensive unit tests
pull/7231/head
mq(kafka): implement ApiVersions protocol handler with manual binary encoding and comprehensive unit tests
pull/7231/head
6 changed files with 370 additions and 83 deletions
-
1go.mod
-
2go.sum
-
76test/kafka/gateway_smoke_test.go
-
97weed/mq/kafka/gateway/server.go
-
115weed/mq/kafka/protocol/handler.go
-
162weed/mq/kafka/protocol/handler_test.go
@ -1,50 +1,48 @@ |
|||
package kafka |
|||
|
|||
import ( |
|||
"net" |
|||
"testing" |
|||
"time" |
|||
"net" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway" |
|||
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway" |
|||
) |
|||
|
|||
func TestGateway_StartAcceptsConnections(t *testing.T) { |
|||
srv := gateway.NewServer(gateway.Options{Listen: ":0"}) |
|||
if err := srv.Start(); err != nil { |
|||
t.Fatalf("start gateway: %v", err) |
|||
} |
|||
addr := srv.Addr() |
|||
if addr == "" { |
|||
t.Fatalf("server Addr() empty") |
|||
} |
|||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second) |
|||
if err != nil { |
|||
t.Fatalf("dial gateway: %v", err) |
|||
} |
|||
_ = conn.Close() |
|||
if err := srv.Close(); err != nil { |
|||
t.Fatalf("close gateway: %v", err) |
|||
} |
|||
srv := gateway.NewServer(gateway.Options{Listen: ":0"}) |
|||
if err := srv.Start(); err != nil { |
|||
t.Fatalf("start gateway: %v", err) |
|||
} |
|||
addr := srv.Addr() |
|||
if addr == "" { |
|||
t.Fatalf("server Addr() empty") |
|||
} |
|||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second) |
|||
if err != nil { |
|||
t.Fatalf("dial gateway: %v", err) |
|||
} |
|||
_ = conn.Close() |
|||
if err := srv.Close(); err != nil { |
|||
t.Fatalf("close gateway: %v", err) |
|||
} |
|||
} |
|||
|
|||
func TestGateway_RefusesAfterClose(t *testing.T) { |
|||
srv := gateway.NewServer(gateway.Options{Listen: ":0"}) |
|||
if err := srv.Start(); err != nil { |
|||
t.Fatalf("start gateway: %v", err) |
|||
} |
|||
addr := srv.Addr() |
|||
if addr == "" { |
|||
t.Fatalf("server Addr() empty") |
|||
} |
|||
if err := srv.Close(); err != nil { |
|||
t.Fatalf("close gateway: %v", err) |
|||
} |
|||
// give the OS a brief moment to release the port
|
|||
time.Sleep(50 * time.Millisecond) |
|||
_, err := net.DialTimeout("tcp", addr, 300*time.Millisecond) |
|||
if err == nil { |
|||
t.Fatalf("expected dial to fail after close") |
|||
} |
|||
srv := gateway.NewServer(gateway.Options{Listen: ":0"}) |
|||
if err := srv.Start(); err != nil { |
|||
t.Fatalf("start gateway: %v", err) |
|||
} |
|||
addr := srv.Addr() |
|||
if addr == "" { |
|||
t.Fatalf("server Addr() empty") |
|||
} |
|||
if err := srv.Close(); err != nil { |
|||
t.Fatalf("close gateway: %v", err) |
|||
} |
|||
// give the OS a brief moment to release the port
|
|||
time.Sleep(50 * time.Millisecond) |
|||
_, err := net.DialTimeout("tcp", addr, 300*time.Millisecond) |
|||
if err == nil { |
|||
t.Fatalf("expected dial to fail after close") |
|||
} |
|||
} |
|||
|
|||
|
|||
@ -0,0 +1,115 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"bufio" |
|||
"encoding/binary" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
) |
|||
|
|||
// Handler processes Kafka protocol requests from clients
|
|||
type Handler struct { |
|||
} |
|||
|
|||
func NewHandler() *Handler { |
|||
return &Handler{} |
|||
} |
|||
|
|||
// HandleConn processes a single client connection
|
|||
func (h *Handler) HandleConn(conn net.Conn) error { |
|||
defer conn.Close() |
|||
|
|||
r := bufio.NewReader(conn) |
|||
w := bufio.NewWriter(conn) |
|||
defer w.Flush() |
|||
|
|||
for { |
|||
// Read message size (4 bytes)
|
|||
var sizeBytes [4]byte |
|||
if _, err := io.ReadFull(r, sizeBytes[:]); err != nil { |
|||
if err == io.EOF { |
|||
return nil // clean disconnect
|
|||
} |
|||
return fmt.Errorf("read size: %w", err) |
|||
} |
|||
|
|||
size := binary.BigEndian.Uint32(sizeBytes[:]) |
|||
if size == 0 || size > 1024*1024 { // 1MB limit
|
|||
return fmt.Errorf("invalid message size: %d", size) |
|||
} |
|||
|
|||
// Read the message
|
|||
messageBuf := make([]byte, size) |
|||
if _, err := io.ReadFull(r, messageBuf); err != nil { |
|||
return fmt.Errorf("read message: %w", err) |
|||
} |
|||
|
|||
// Parse at least the basic header to get API key and correlation ID
|
|||
if len(messageBuf) < 8 { |
|||
return fmt.Errorf("message too short") |
|||
} |
|||
|
|||
apiKey := binary.BigEndian.Uint16(messageBuf[0:2]) |
|||
apiVersion := binary.BigEndian.Uint16(messageBuf[2:4]) |
|||
correlationID := binary.BigEndian.Uint32(messageBuf[4:8]) |
|||
|
|||
// Handle the request based on API key
|
|||
var response []byte |
|||
var err error |
|||
|
|||
switch apiKey { |
|||
case 18: // ApiVersions
|
|||
response, err = h.handleApiVersions(correlationID) |
|||
default: |
|||
err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion) |
|||
} |
|||
|
|||
if err != nil { |
|||
return fmt.Errorf("handle request: %w", err) |
|||
} |
|||
|
|||
// Write response size and data
|
|||
responseSizeBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(responseSizeBytes, uint32(len(response))) |
|||
|
|||
if _, err := w.Write(responseSizeBytes); err != nil { |
|||
return fmt.Errorf("write response size: %w", err) |
|||
} |
|||
if _, err := w.Write(response); err != nil { |
|||
return fmt.Errorf("write response: %w", err) |
|||
} |
|||
|
|||
if err := w.Flush(); err != nil { |
|||
return fmt.Errorf("flush response: %w", err) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) { |
|||
// Build ApiVersions response manually
|
|||
// Response format: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4)
|
|||
|
|||
response := make([]byte, 0, 64) |
|||
|
|||
// Correlation ID
|
|||
correlationIDBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(correlationIDBytes, correlationID) |
|||
response = append(response, correlationIDBytes...) |
|||
|
|||
// Error code (0 = no error)
|
|||
response = append(response, 0, 0) |
|||
|
|||
// Number of API keys (compact array format in newer versions, but using basic format for simplicity)
|
|||
response = append(response, 0, 0, 0, 1) // 1 API key
|
|||
|
|||
// API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2)
|
|||
response = append(response, 0, 18) // API key 18
|
|||
response = append(response, 0, 0) // min version 0
|
|||
response = append(response, 0, 3) // max version 3
|
|||
|
|||
// Throttle time (4 bytes, 0 = no throttling)
|
|||
response = append(response, 0, 0, 0, 0) |
|||
|
|||
return response, nil |
|||
} |
|||
@ -0,0 +1,162 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"net" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestHandler_ApiVersions(t *testing.T) { |
|||
// Create handler
|
|||
h := NewHandler() |
|||
|
|||
// Create in-memory connection
|
|||
server, client := net.Pipe() |
|||
defer server.Close() |
|||
defer client.Close() |
|||
|
|||
// Handle connection in background
|
|||
done := make(chan error, 1) |
|||
go func() { |
|||
done <- h.HandleConn(server) |
|||
}() |
|||
|
|||
// Create ApiVersions request manually
|
|||
// Request format: api_key(2) + api_version(2) + correlation_id(4) + client_id_size(2) + client_id + body
|
|||
correlationID := uint32(12345) |
|||
clientID := "test-client" |
|||
|
|||
message := make([]byte, 0, 64) |
|||
message = append(message, 0, 18) // API key 18 (ApiVersions)
|
|||
message = append(message, 0, 0) // API version 0
|
|||
|
|||
// Correlation ID
|
|||
correlationIDBytes := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(correlationIDBytes, correlationID) |
|||
message = append(message, correlationIDBytes...) |
|||
|
|||
// Client ID length and string
|
|||
clientIDLen := uint16(len(clientID)) |
|||
message = append(message, byte(clientIDLen>>8), byte(clientIDLen)) |
|||
message = append(message, []byte(clientID)...) |
|||
|
|||
// Empty request body for ApiVersions
|
|||
|
|||
// Write message size and data
|
|||
messageSize := uint32(len(message)) |
|||
sizeBuf := make([]byte, 4) |
|||
binary.BigEndian.PutUint32(sizeBuf, messageSize) |
|||
|
|||
if _, err := client.Write(sizeBuf); err != nil { |
|||
t.Fatalf("write size: %v", err) |
|||
} |
|||
if _, err := client.Write(message); err != nil { |
|||
t.Fatalf("write message: %v", err) |
|||
} |
|||
|
|||
// Read response size
|
|||
var respSizeBuf [4]byte |
|||
client.SetReadDeadline(time.Now().Add(5 * time.Second)) |
|||
if _, err := client.Read(respSizeBuf[:]); err != nil { |
|||
t.Fatalf("read response size: %v", err) |
|||
} |
|||
|
|||
respSize := binary.BigEndian.Uint32(respSizeBuf[:]) |
|||
if respSize == 0 || respSize > 1024*1024 { |
|||
t.Fatalf("invalid response size: %d", respSize) |
|||
} |
|||
|
|||
// Read response data
|
|||
respBuf := make([]byte, respSize) |
|||
if _, err := client.Read(respBuf); err != nil { |
|||
t.Fatalf("read response: %v", err) |
|||
} |
|||
|
|||
// Parse response: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4)
|
|||
if len(respBuf) < 14 { // minimum response size
|
|||
t.Fatalf("response too short: %d bytes", len(respBuf)) |
|||
} |
|||
|
|||
// Check correlation ID
|
|||
respCorrelationID := binary.BigEndian.Uint32(respBuf[0:4]) |
|||
if respCorrelationID != correlationID { |
|||
t.Errorf("correlation ID mismatch: got %d, want %d", respCorrelationID, correlationID) |
|||
} |
|||
|
|||
// Check error code
|
|||
errorCode := binary.BigEndian.Uint16(respBuf[4:6]) |
|||
if errorCode != 0 { |
|||
t.Errorf("expected no error, got error code: %d", errorCode) |
|||
} |
|||
|
|||
// Check number of API keys
|
|||
numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10]) |
|||
if numAPIKeys != 1 { |
|||
t.Errorf("expected 1 API key, got: %d", numAPIKeys) |
|||
} |
|||
|
|||
// Check API key details: api_key(2) + min_version(2) + max_version(2)
|
|||
if len(respBuf) < 16 { |
|||
t.Fatalf("response too short for API key data") |
|||
} |
|||
|
|||
apiKey := binary.BigEndian.Uint16(respBuf[10:12]) |
|||
minVersion := binary.BigEndian.Uint16(respBuf[12:14]) |
|||
maxVersion := binary.BigEndian.Uint16(respBuf[14:16]) |
|||
|
|||
if apiKey != 18 { |
|||
t.Errorf("expected API key 18, got: %d", apiKey) |
|||
} |
|||
if minVersion != 0 { |
|||
t.Errorf("expected min version 0, got: %d", minVersion) |
|||
} |
|||
if maxVersion != 3 { |
|||
t.Errorf("expected max version 3, got: %d", maxVersion) |
|||
} |
|||
|
|||
// Close client to end handler
|
|||
client.Close() |
|||
|
|||
// Wait for handler to complete
|
|||
select { |
|||
case err := <-done: |
|||
if err != nil { |
|||
t.Errorf("handler error: %v", err) |
|||
} |
|||
case <-time.After(2 * time.Second): |
|||
t.Errorf("handler did not complete in time") |
|||
} |
|||
} |
|||
|
|||
func TestHandler_handleApiVersions(t *testing.T) { |
|||
h := NewHandler() |
|||
correlationID := uint32(999) |
|||
|
|||
response, err := h.handleApiVersions(correlationID) |
|||
if err != nil { |
|||
t.Fatalf("handleApiVersions: %v", err) |
|||
} |
|||
|
|||
if len(response) < 20 { // minimum expected size
|
|||
t.Fatalf("response too short: %d bytes", len(response)) |
|||
} |
|||
|
|||
// Check correlation ID
|
|||
respCorrelationID := binary.BigEndian.Uint32(response[0:4]) |
|||
if respCorrelationID != correlationID { |
|||
t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID) |
|||
} |
|||
|
|||
// Check error code
|
|||
errorCode := binary.BigEndian.Uint16(response[4:6]) |
|||
if errorCode != 0 { |
|||
t.Errorf("error code: got %d, want 0", errorCode) |
|||
} |
|||
|
|||
// Check API key
|
|||
apiKey := binary.BigEndian.Uint16(response[10:12]) |
|||
if apiKey != 18 { |
|||
t.Errorf("API key: got %d, want 18", apiKey) |
|||
} |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue