Browse Source

mq(kafka): implement ApiVersions protocol handler with manual binary encoding and comprehensive unit tests

pull/7231/head
chrislu 2 months ago
parent
commit
7c4a5f546c
  1. 1
      go.mod
  2. 2
      go.sum
  3. 76
      test/kafka/gateway_smoke_test.go
  4. 97
      weed/mq/kafka/gateway/server.go
  5. 115
      weed/mq/kafka/protocol/handler.go
  6. 162
      weed/mq/kafka/protocol/handler_test.go

1
go.mod

@ -154,6 +154,7 @@ require (
github.com/rdleal/intervalst v1.5.0
github.com/redis/go-redis/v9 v9.12.1
github.com/schollz/progressbar/v3 v3.18.0
github.com/segmentio/kafka-go v0.4.49
github.com/shirou/gopsutil/v3 v3.24.5
github.com/tarantool/go-tarantool/v2 v2.4.0
github.com/tikv/client-go/v2 v2.0.7

2
go.sum

@ -1638,6 +1638,8 @@ github.com/seaweedfs/goexif v1.0.3 h1:ve/OjI7dxPW8X9YQsv3JuVMaxEyF9Rvfd04ouL+Bz3
github.com/seaweedfs/goexif v1.0.3/go.mod h1:Oni780Z236sXpIQzk1XoJlTwqrJ02smEin9zQeff7Fk=
github.com/seaweedfs/raft v1.1.3 h1:5B6hgneQ7IuU4Ceom/f6QUt8pEeqjcsRo+IxlyPZCws=
github.com/seaweedfs/raft v1.1.3/go.mod h1:9cYlEBA+djJbnf/5tWsCybtbL7ICYpi+Uxcg3MxjuNs=
github.com/segmentio/kafka-go v0.4.49 h1:GJiNX1d/g+kG6ljyJEoi9++PUMdXGAxb7JGPiDCuNmk=
github.com/segmentio/kafka-go v0.4.49/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ=

76
test/kafka/gateway_smoke_test.go

@ -1,50 +1,48 @@
package kafka
import (
"net"
"testing"
"time"
"net"
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway"
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/gateway"
)
func TestGateway_StartAcceptsConnections(t *testing.T) {
srv := gateway.NewServer(gateway.Options{Listen: ":0"})
if err := srv.Start(); err != nil {
t.Fatalf("start gateway: %v", err)
}
addr := srv.Addr()
if addr == "" {
t.Fatalf("server Addr() empty")
}
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial gateway: %v", err)
}
_ = conn.Close()
if err := srv.Close(); err != nil {
t.Fatalf("close gateway: %v", err)
}
srv := gateway.NewServer(gateway.Options{Listen: ":0"})
if err := srv.Start(); err != nil {
t.Fatalf("start gateway: %v", err)
}
addr := srv.Addr()
if addr == "" {
t.Fatalf("server Addr() empty")
}
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial gateway: %v", err)
}
_ = conn.Close()
if err := srv.Close(); err != nil {
t.Fatalf("close gateway: %v", err)
}
}
func TestGateway_RefusesAfterClose(t *testing.T) {
srv := gateway.NewServer(gateway.Options{Listen: ":0"})
if err := srv.Start(); err != nil {
t.Fatalf("start gateway: %v", err)
}
addr := srv.Addr()
if addr == "" {
t.Fatalf("server Addr() empty")
}
if err := srv.Close(); err != nil {
t.Fatalf("close gateway: %v", err)
}
// give the OS a brief moment to release the port
time.Sleep(50 * time.Millisecond)
_, err := net.DialTimeout("tcp", addr, 300*time.Millisecond)
if err == nil {
t.Fatalf("expected dial to fail after close")
}
srv := gateway.NewServer(gateway.Options{Listen: ":0"})
if err := srv.Start(); err != nil {
t.Fatalf("start gateway: %v", err)
}
addr := srv.Addr()
if addr == "" {
t.Fatalf("server Addr() empty")
}
if err := srv.Close(); err != nil {
t.Fatalf("close gateway: %v", err)
}
// give the OS a brief moment to release the port
time.Sleep(50 * time.Millisecond)
_, err := net.DialTimeout("tcp", addr, 300*time.Millisecond)
if err == nil {
t.Fatalf("expected dial to fail after close")
}
}

97
weed/mq/kafka/gateway/server.go

@ -4,74 +4,83 @@ import (
"context"
"net"
"sync"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/protocol"
)
type Options struct {
Listen string
Listen string
}
type Server struct {
opts Options
ln net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
opts Options
ln net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
handler *protocol.Handler
}
func NewServer(opts Options) *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{opts: opts, ctx: ctx, cancel: cancel}
ctx, cancel := context.WithCancel(context.Background())
return &Server{
opts: opts,
ctx: ctx,
cancel: cancel,
handler: protocol.NewHandler(),
}
}
func (s *Server) Start() error {
ln, err := net.Listen("tcp", s.opts.Listen)
if err != nil {
return err
}
s.ln = ln
s.wg.Add(1)
go func() {
defer s.wg.Done()
for {
conn, err := s.ln.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
return
}
}
ln, err := net.Listen("tcp", s.opts.Listen)
if err != nil {
return err
}
s.ln = ln
s.wg.Add(1)
go func() {
defer s.wg.Done()
for {
conn, err := s.ln.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
return
}
}
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
_ = c.Close()
if err := s.handler.HandleConn(c); err != nil {
glog.V(1).Infof("handle conn %v: %v", c.RemoteAddr(), err)
}
}(conn)
}
}()
return nil
}
}()
return nil
}
func (s *Server) Wait() error {
s.wg.Wait()
return nil
s.wg.Wait()
return nil
}
func (s *Server) Close() error {
s.cancel()
if s.ln != nil {
_ = s.ln.Close()
}
s.wg.Wait()
return nil
s.cancel()
if s.ln != nil {
_ = s.ln.Close()
}
s.wg.Wait()
return nil
}
// Addr returns the bound address of the server listener, or empty if not started.
func (s *Server) Addr() string {
if s.ln == nil {
return ""
}
return s.ln.Addr().String()
if s.ln == nil {
return ""
}
return s.ln.Addr().String()
}

115
weed/mq/kafka/protocol/handler.go

@ -0,0 +1,115 @@
package protocol
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
)
// Handler processes Kafka protocol requests from clients
type Handler struct {
}
func NewHandler() *Handler {
return &Handler{}
}
// HandleConn processes a single client connection
func (h *Handler) HandleConn(conn net.Conn) error {
defer conn.Close()
r := bufio.NewReader(conn)
w := bufio.NewWriter(conn)
defer w.Flush()
for {
// Read message size (4 bytes)
var sizeBytes [4]byte
if _, err := io.ReadFull(r, sizeBytes[:]); err != nil {
if err == io.EOF {
return nil // clean disconnect
}
return fmt.Errorf("read size: %w", err)
}
size := binary.BigEndian.Uint32(sizeBytes[:])
if size == 0 || size > 1024*1024 { // 1MB limit
return fmt.Errorf("invalid message size: %d", size)
}
// Read the message
messageBuf := make([]byte, size)
if _, err := io.ReadFull(r, messageBuf); err != nil {
return fmt.Errorf("read message: %w", err)
}
// Parse at least the basic header to get API key and correlation ID
if len(messageBuf) < 8 {
return fmt.Errorf("message too short")
}
apiKey := binary.BigEndian.Uint16(messageBuf[0:2])
apiVersion := binary.BigEndian.Uint16(messageBuf[2:4])
correlationID := binary.BigEndian.Uint32(messageBuf[4:8])
// Handle the request based on API key
var response []byte
var err error
switch apiKey {
case 18: // ApiVersions
response, err = h.handleApiVersions(correlationID)
default:
err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion)
}
if err != nil {
return fmt.Errorf("handle request: %w", err)
}
// Write response size and data
responseSizeBytes := make([]byte, 4)
binary.BigEndian.PutUint32(responseSizeBytes, uint32(len(response)))
if _, err := w.Write(responseSizeBytes); err != nil {
return fmt.Errorf("write response size: %w", err)
}
if _, err := w.Write(response); err != nil {
return fmt.Errorf("write response: %w", err)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("flush response: %w", err)
}
}
}
func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
// Build ApiVersions response manually
// Response format: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4)
response := make([]byte, 0, 64)
// Correlation ID
correlationIDBytes := make([]byte, 4)
binary.BigEndian.PutUint32(correlationIDBytes, correlationID)
response = append(response, correlationIDBytes...)
// Error code (0 = no error)
response = append(response, 0, 0)
// Number of API keys (compact array format in newer versions, but using basic format for simplicity)
response = append(response, 0, 0, 0, 1) // 1 API key
// API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2)
response = append(response, 0, 18) // API key 18
response = append(response, 0, 0) // min version 0
response = append(response, 0, 3) // max version 3
// Throttle time (4 bytes, 0 = no throttling)
response = append(response, 0, 0, 0, 0)
return response, nil
}

162
weed/mq/kafka/protocol/handler_test.go

@ -0,0 +1,162 @@
package protocol
import (
"encoding/binary"
"net"
"testing"
"time"
)
func TestHandler_ApiVersions(t *testing.T) {
// Create handler
h := NewHandler()
// Create in-memory connection
server, client := net.Pipe()
defer server.Close()
defer client.Close()
// Handle connection in background
done := make(chan error, 1)
go func() {
done <- h.HandleConn(server)
}()
// Create ApiVersions request manually
// Request format: api_key(2) + api_version(2) + correlation_id(4) + client_id_size(2) + client_id + body
correlationID := uint32(12345)
clientID := "test-client"
message := make([]byte, 0, 64)
message = append(message, 0, 18) // API key 18 (ApiVersions)
message = append(message, 0, 0) // API version 0
// Correlation ID
correlationIDBytes := make([]byte, 4)
binary.BigEndian.PutUint32(correlationIDBytes, correlationID)
message = append(message, correlationIDBytes...)
// Client ID length and string
clientIDLen := uint16(len(clientID))
message = append(message, byte(clientIDLen>>8), byte(clientIDLen))
message = append(message, []byte(clientID)...)
// Empty request body for ApiVersions
// Write message size and data
messageSize := uint32(len(message))
sizeBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sizeBuf, messageSize)
if _, err := client.Write(sizeBuf); err != nil {
t.Fatalf("write size: %v", err)
}
if _, err := client.Write(message); err != nil {
t.Fatalf("write message: %v", err)
}
// Read response size
var respSizeBuf [4]byte
client.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := client.Read(respSizeBuf[:]); err != nil {
t.Fatalf("read response size: %v", err)
}
respSize := binary.BigEndian.Uint32(respSizeBuf[:])
if respSize == 0 || respSize > 1024*1024 {
t.Fatalf("invalid response size: %d", respSize)
}
// Read response data
respBuf := make([]byte, respSize)
if _, err := client.Read(respBuf); err != nil {
t.Fatalf("read response: %v", err)
}
// Parse response: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4)
if len(respBuf) < 14 { // minimum response size
t.Fatalf("response too short: %d bytes", len(respBuf))
}
// Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(respBuf[0:4])
if respCorrelationID != correlationID {
t.Errorf("correlation ID mismatch: got %d, want %d", respCorrelationID, correlationID)
}
// Check error code
errorCode := binary.BigEndian.Uint16(respBuf[4:6])
if errorCode != 0 {
t.Errorf("expected no error, got error code: %d", errorCode)
}
// Check number of API keys
numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10])
if numAPIKeys != 1 {
t.Errorf("expected 1 API key, got: %d", numAPIKeys)
}
// Check API key details: api_key(2) + min_version(2) + max_version(2)
if len(respBuf) < 16 {
t.Fatalf("response too short for API key data")
}
apiKey := binary.BigEndian.Uint16(respBuf[10:12])
minVersion := binary.BigEndian.Uint16(respBuf[12:14])
maxVersion := binary.BigEndian.Uint16(respBuf[14:16])
if apiKey != 18 {
t.Errorf("expected API key 18, got: %d", apiKey)
}
if minVersion != 0 {
t.Errorf("expected min version 0, got: %d", minVersion)
}
if maxVersion != 3 {
t.Errorf("expected max version 3, got: %d", maxVersion)
}
// Close client to end handler
client.Close()
// Wait for handler to complete
select {
case err := <-done:
if err != nil {
t.Errorf("handler error: %v", err)
}
case <-time.After(2 * time.Second):
t.Errorf("handler did not complete in time")
}
}
func TestHandler_handleApiVersions(t *testing.T) {
h := NewHandler()
correlationID := uint32(999)
response, err := h.handleApiVersions(correlationID)
if err != nil {
t.Fatalf("handleApiVersions: %v", err)
}
if len(response) < 20 { // minimum expected size
t.Fatalf("response too short: %d bytes", len(response))
}
// Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(response[0:4])
if respCorrelationID != correlationID {
t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID)
}
// Check error code
errorCode := binary.BigEndian.Uint16(response[4:6])
if errorCode != 0 {
t.Errorf("error code: got %d, want 0", errorCode)
}
// Check API key
apiKey := binary.BigEndian.Uint16(response[10:12])
if apiKey != 18 {
t.Errorf("API key: got %d, want 18", apiKey)
}
}
Loading…
Cancel
Save