From 7c4a5f546cb3f6420f7e5b62e48b8d9df7a49a1e Mon Sep 17 00:00:00 2001 From: chrislu Date: Wed, 10 Sep 2025 11:21:52 -0700 Subject: [PATCH] mq(kafka): implement ApiVersions protocol handler with manual binary encoding and comprehensive unit tests --- go.mod | 1 + go.sum | 2 + test/kafka/gateway_smoke_test.go | 76 ++++++------ weed/mq/kafka/gateway/server.go | 97 ++++++++------- weed/mq/kafka/protocol/handler.go | 115 ++++++++++++++++++ weed/mq/kafka/protocol/handler_test.go | 162 +++++++++++++++++++++++++ 6 files changed, 370 insertions(+), 83 deletions(-) create mode 100644 weed/mq/kafka/protocol/handler.go create mode 100644 weed/mq/kafka/protocol/handler_test.go diff --git a/go.mod b/go.mod index 2779c3226..1892fec6d 100644 --- a/go.mod +++ b/go.mod @@ -154,6 +154,7 @@ require ( github.com/rdleal/intervalst v1.5.0 github.com/redis/go-redis/v9 v9.12.1 github.com/schollz/progressbar/v3 v3.18.0 + github.com/segmentio/kafka-go v0.4.49 github.com/shirou/gopsutil/v3 v3.24.5 github.com/tarantool/go-tarantool/v2 v2.4.0 github.com/tikv/client-go/v2 v2.0.7 diff --git a/go.sum b/go.sum index ca130fece..b49d03f99 100644 --- a/go.sum +++ b/go.sum @@ -1638,6 +1638,8 @@ github.com/seaweedfs/goexif v1.0.3 h1:ve/OjI7dxPW8X9YQsv3JuVMaxEyF9Rvfd04ouL+Bz3 github.com/seaweedfs/goexif v1.0.3/go.mod h1:Oni780Z236sXpIQzk1XoJlTwqrJ02smEin9zQeff7Fk= github.com/seaweedfs/raft v1.1.3 h1:5B6hgneQ7IuU4Ceom/f6QUt8pEeqjcsRo+IxlyPZCws= github.com/seaweedfs/raft v1.1.3/go.mod h1:9cYlEBA+djJbnf/5tWsCybtbL7ICYpi+Uxcg3MxjuNs= +github.com/segmentio/kafka-go v0.4.49 h1:GJiNX1d/g+kG6ljyJEoi9++PUMdXGAxb7JGPiDCuNmk= +github.com/segmentio/kafka-go v0.4.49/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= diff --git a/test/kafka/gateway_smoke_test.go b/test/kafka/gateway_smoke_test.go index b8098b7a9..ad691efe4 100644 --- a/test/kafka/gateway_smoke_test.go +++ b/test/kafka/gateway_smoke_test.go @@ -1,50 +1,48 @@ package kafka import ( - "net" - "testing" - "time" + "net" + "testing" + "time" - "github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway" + "github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway" ) func TestGateway_StartAcceptsConnections(t *testing.T) { - srv := gateway.NewServer(gateway.Options{Listen: ":0"}) - if err := srv.Start(); err != nil { - t.Fatalf("start gateway: %v", err) - } - addr := srv.Addr() - if addr == "" { - t.Fatalf("server Addr() empty") - } - conn, err := net.DialTimeout("tcp", addr, 2*time.Second) - if err != nil { - t.Fatalf("dial gateway: %v", err) - } - _ = conn.Close() - if err := srv.Close(); err != nil { - t.Fatalf("close gateway: %v", err) - } + srv := gateway.NewServer(gateway.Options{Listen: ":0"}) + if err := srv.Start(); err != nil { + t.Fatalf("start gateway: %v", err) + } + addr := srv.Addr() + if addr == "" { + t.Fatalf("server Addr() empty") + } + conn, err := net.DialTimeout("tcp", addr, 2*time.Second) + if err != nil { + t.Fatalf("dial gateway: %v", err) + } + _ = conn.Close() + if err := srv.Close(); err != nil { + t.Fatalf("close gateway: %v", err) + } } func TestGateway_RefusesAfterClose(t *testing.T) { - srv := gateway.NewServer(gateway.Options{Listen: ":0"}) - if err := srv.Start(); err != nil { - t.Fatalf("start gateway: %v", err) - } - addr := srv.Addr() - if addr == "" { - t.Fatalf("server Addr() empty") - } - if err := srv.Close(); err != nil { - t.Fatalf("close gateway: %v", err) - } - // give the OS a brief moment to release the port - time.Sleep(50 * time.Millisecond) - _, err := net.DialTimeout("tcp", addr, 300*time.Millisecond) - if err == nil { - t.Fatalf("expected dial to fail after close") - } + srv := gateway.NewServer(gateway.Options{Listen: ":0"}) + if err := srv.Start(); err != nil { + t.Fatalf("start gateway: %v", err) + } + addr := srv.Addr() + if addr == "" { + t.Fatalf("server Addr() empty") + } + if err := srv.Close(); err != nil { + t.Fatalf("close gateway: %v", err) + } + // give the OS a brief moment to release the port + time.Sleep(50 * time.Millisecond) + _, err := net.DialTimeout("tcp", addr, 300*time.Millisecond) + if err == nil { + t.Fatalf("expected dial to fail after close") + } } - - diff --git a/weed/mq/kafka/gateway/server.go b/weed/mq/kafka/gateway/server.go index 857993cc4..b3e0b8239 100644 --- a/weed/mq/kafka/gateway/server.go +++ b/weed/mq/kafka/gateway/server.go @@ -4,74 +4,83 @@ import ( "context" "net" "sync" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/mq/kafka/protocol" ) type Options struct { - Listen string + Listen string } type Server struct { - opts Options - ln net.Listener - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc + opts Options + ln net.Listener + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + handler *protocol.Handler } func NewServer(opts Options) *Server { - ctx, cancel := context.WithCancel(context.Background()) - return &Server{opts: opts, ctx: ctx, cancel: cancel} + ctx, cancel := context.WithCancel(context.Background()) + return &Server{ + opts: opts, + ctx: ctx, + cancel: cancel, + handler: protocol.NewHandler(), + } } func (s *Server) Start() error { - ln, err := net.Listen("tcp", s.opts.Listen) - if err != nil { - return err - } - s.ln = ln - s.wg.Add(1) - go func() { - defer s.wg.Done() - for { - conn, err := s.ln.Accept() - if err != nil { - select { - case <-s.ctx.Done(): - return - default: - return - } - } + ln, err := net.Listen("tcp", s.opts.Listen) + if err != nil { + return err + } + s.ln = ln + s.wg.Add(1) + go func() { + defer s.wg.Done() + for { + conn, err := s.ln.Accept() + if err != nil { + select { + case <-s.ctx.Done(): + return + default: + return + } + } s.wg.Add(1) go func(c net.Conn) { defer s.wg.Done() - _ = c.Close() + if err := s.handler.HandleConn(c); err != nil { + glog.V(1).Infof("handle conn %v: %v", c.RemoteAddr(), err) + } }(conn) - } - }() - return nil + } + }() + return nil } func (s *Server) Wait() error { - s.wg.Wait() - return nil + s.wg.Wait() + return nil } func (s *Server) Close() error { - s.cancel() - if s.ln != nil { - _ = s.ln.Close() - } - s.wg.Wait() - return nil + s.cancel() + if s.ln != nil { + _ = s.ln.Close() + } + s.wg.Wait() + return nil } // Addr returns the bound address of the server listener, or empty if not started. func (s *Server) Addr() string { - if s.ln == nil { - return "" - } - return s.ln.Addr().String() + if s.ln == nil { + return "" + } + return s.ln.Addr().String() } - - diff --git a/weed/mq/kafka/protocol/handler.go b/weed/mq/kafka/protocol/handler.go new file mode 100644 index 000000000..912fc7531 --- /dev/null +++ b/weed/mq/kafka/protocol/handler.go @@ -0,0 +1,115 @@ +package protocol + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "net" +) + +// Handler processes Kafka protocol requests from clients +type Handler struct { +} + +func NewHandler() *Handler { + return &Handler{} +} + +// HandleConn processes a single client connection +func (h *Handler) HandleConn(conn net.Conn) error { + defer conn.Close() + + r := bufio.NewReader(conn) + w := bufio.NewWriter(conn) + defer w.Flush() + + for { + // Read message size (4 bytes) + var sizeBytes [4]byte + if _, err := io.ReadFull(r, sizeBytes[:]); err != nil { + if err == io.EOF { + return nil // clean disconnect + } + return fmt.Errorf("read size: %w", err) + } + + size := binary.BigEndian.Uint32(sizeBytes[:]) + if size == 0 || size > 1024*1024 { // 1MB limit + return fmt.Errorf("invalid message size: %d", size) + } + + // Read the message + messageBuf := make([]byte, size) + if _, err := io.ReadFull(r, messageBuf); err != nil { + return fmt.Errorf("read message: %w", err) + } + + // Parse at least the basic header to get API key and correlation ID + if len(messageBuf) < 8 { + return fmt.Errorf("message too short") + } + + apiKey := binary.BigEndian.Uint16(messageBuf[0:2]) + apiVersion := binary.BigEndian.Uint16(messageBuf[2:4]) + correlationID := binary.BigEndian.Uint32(messageBuf[4:8]) + + // Handle the request based on API key + var response []byte + var err error + + switch apiKey { + case 18: // ApiVersions + response, err = h.handleApiVersions(correlationID) + default: + err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion) + } + + if err != nil { + return fmt.Errorf("handle request: %w", err) + } + + // Write response size and data + responseSizeBytes := make([]byte, 4) + binary.BigEndian.PutUint32(responseSizeBytes, uint32(len(response))) + + if _, err := w.Write(responseSizeBytes); err != nil { + return fmt.Errorf("write response size: %w", err) + } + if _, err := w.Write(response); err != nil { + return fmt.Errorf("write response: %w", err) + } + + if err := w.Flush(); err != nil { + return fmt.Errorf("flush response: %w", err) + } + } +} + +func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) { + // Build ApiVersions response manually + // Response format: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4) + + response := make([]byte, 0, 64) + + // Correlation ID + correlationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(correlationIDBytes, correlationID) + response = append(response, correlationIDBytes...) + + // Error code (0 = no error) + response = append(response, 0, 0) + + // Number of API keys (compact array format in newer versions, but using basic format for simplicity) + response = append(response, 0, 0, 0, 1) // 1 API key + + // API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2) + response = append(response, 0, 18) // API key 18 + response = append(response, 0, 0) // min version 0 + response = append(response, 0, 3) // max version 3 + + // Throttle time (4 bytes, 0 = no throttling) + response = append(response, 0, 0, 0, 0) + + return response, nil +} diff --git a/weed/mq/kafka/protocol/handler_test.go b/weed/mq/kafka/protocol/handler_test.go new file mode 100644 index 000000000..dd863ef1e --- /dev/null +++ b/weed/mq/kafka/protocol/handler_test.go @@ -0,0 +1,162 @@ +package protocol + +import ( + "encoding/binary" + "net" + "testing" + "time" +) + +func TestHandler_ApiVersions(t *testing.T) { + // Create handler + h := NewHandler() + + // Create in-memory connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + // Handle connection in background + done := make(chan error, 1) + go func() { + done <- h.HandleConn(server) + }() + + // Create ApiVersions request manually + // Request format: api_key(2) + api_version(2) + correlation_id(4) + client_id_size(2) + client_id + body + correlationID := uint32(12345) + clientID := "test-client" + + message := make([]byte, 0, 64) + message = append(message, 0, 18) // API key 18 (ApiVersions) + message = append(message, 0, 0) // API version 0 + + // Correlation ID + correlationIDBytes := make([]byte, 4) + binary.BigEndian.PutUint32(correlationIDBytes, correlationID) + message = append(message, correlationIDBytes...) + + // Client ID length and string + clientIDLen := uint16(len(clientID)) + message = append(message, byte(clientIDLen>>8), byte(clientIDLen)) + message = append(message, []byte(clientID)...) + + // Empty request body for ApiVersions + + // Write message size and data + messageSize := uint32(len(message)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, messageSize) + + if _, err := client.Write(sizeBuf); err != nil { + t.Fatalf("write size: %v", err) + } + if _, err := client.Write(message); err != nil { + t.Fatalf("write message: %v", err) + } + + // Read response size + var respSizeBuf [4]byte + client.SetReadDeadline(time.Now().Add(5 * time.Second)) + if _, err := client.Read(respSizeBuf[:]); err != nil { + t.Fatalf("read response size: %v", err) + } + + respSize := binary.BigEndian.Uint32(respSizeBuf[:]) + if respSize == 0 || respSize > 1024*1024 { + t.Fatalf("invalid response size: %d", respSize) + } + + // Read response data + respBuf := make([]byte, respSize) + if _, err := client.Read(respBuf); err != nil { + t.Fatalf("read response: %v", err) + } + + // Parse response: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4) + if len(respBuf) < 14 { // minimum response size + t.Fatalf("response too short: %d bytes", len(respBuf)) + } + + // Check correlation ID + respCorrelationID := binary.BigEndian.Uint32(respBuf[0:4]) + if respCorrelationID != correlationID { + t.Errorf("correlation ID mismatch: got %d, want %d", respCorrelationID, correlationID) + } + + // Check error code + errorCode := binary.BigEndian.Uint16(respBuf[4:6]) + if errorCode != 0 { + t.Errorf("expected no error, got error code: %d", errorCode) + } + + // Check number of API keys + numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10]) + if numAPIKeys != 1 { + t.Errorf("expected 1 API key, got: %d", numAPIKeys) + } + + // Check API key details: api_key(2) + min_version(2) + max_version(2) + if len(respBuf) < 16 { + t.Fatalf("response too short for API key data") + } + + apiKey := binary.BigEndian.Uint16(respBuf[10:12]) + minVersion := binary.BigEndian.Uint16(respBuf[12:14]) + maxVersion := binary.BigEndian.Uint16(respBuf[14:16]) + + if apiKey != 18 { + t.Errorf("expected API key 18, got: %d", apiKey) + } + if minVersion != 0 { + t.Errorf("expected min version 0, got: %d", minVersion) + } + if maxVersion != 3 { + t.Errorf("expected max version 3, got: %d", maxVersion) + } + + // Close client to end handler + client.Close() + + // Wait for handler to complete + select { + case err := <-done: + if err != nil { + t.Errorf("handler error: %v", err) + } + case <-time.After(2 * time.Second): + t.Errorf("handler did not complete in time") + } +} + +func TestHandler_handleApiVersions(t *testing.T) { + h := NewHandler() + correlationID := uint32(999) + + response, err := h.handleApiVersions(correlationID) + if err != nil { + t.Fatalf("handleApiVersions: %v", err) + } + + if len(response) < 20 { // minimum expected size + t.Fatalf("response too short: %d bytes", len(response)) + } + + // Check correlation ID + respCorrelationID := binary.BigEndian.Uint32(response[0:4]) + if respCorrelationID != correlationID { + t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID) + } + + // Check error code + errorCode := binary.BigEndian.Uint16(response[4:6]) + if errorCode != 0 { + t.Errorf("error code: got %d, want 0", errorCode) + } + + // Check API key + apiKey := binary.BigEndian.Uint16(response[10:12]) + if apiKey != 18 { + t.Errorf("API key: got %d, want 18", apiKey) + } +}