Browse Source

fmt

pull/7231/head
chrislu 2 months ago
parent
commit
fdb7f94526
  1. 36
      weed/mq/kafka/gateway/server.go
  2. 28
      weed/mq/kafka/protocol/handler.go
  3. 36
      weed/mq/kafka/protocol/handler_test.go

36
weed/mq/kafka/gateway/server.go

@ -1,12 +1,12 @@
package gateway package gateway
import ( import (
"context"
"net"
"sync"
"context"
"net"
"sync"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/protocol"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/mq/kafka/protocol"
) )
type Options struct { type Options struct {
@ -14,12 +14,12 @@ type Options struct {
} }
type Server struct { type Server struct {
opts Options
ln net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
handler *protocol.Handler
opts Options
ln net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
handler *protocol.Handler
} }
func NewServer(opts Options) *Server { func NewServer(opts Options) *Server {
@ -51,13 +51,13 @@ func (s *Server) Start() error {
return return
} }
} }
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
if err := s.handler.HandleConn(c); err != nil {
glog.V(1).Infof("handle conn %v: %v", c.RemoteAddr(), err)
}
}(conn)
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
if err := s.handler.HandleConn(c); err != nil {
glog.V(1).Infof("handle conn %v: %v", c.RemoteAddr(), err)
}
}(conn)
} }
}() }()
return nil return nil

28
weed/mq/kafka/protocol/handler.go

@ -19,7 +19,7 @@ func NewHandler() *Handler {
// HandleConn processes a single client connection // HandleConn processes a single client connection
func (h *Handler) HandleConn(conn net.Conn) error { func (h *Handler) HandleConn(conn net.Conn) error {
defer conn.Close() defer conn.Close()
r := bufio.NewReader(conn) r := bufio.NewReader(conn)
w := bufio.NewWriter(conn) w := bufio.NewWriter(conn)
defer w.Flush() defer w.Flush()
@ -33,7 +33,7 @@ func (h *Handler) HandleConn(conn net.Conn) error {
} }
return fmt.Errorf("read size: %w", err) return fmt.Errorf("read size: %w", err)
} }
size := binary.BigEndian.Uint32(sizeBytes[:]) size := binary.BigEndian.Uint32(sizeBytes[:])
if size == 0 || size > 1024*1024 { // 1MB limit if size == 0 || size > 1024*1024 { // 1MB limit
return fmt.Errorf("invalid message size: %d", size) return fmt.Errorf("invalid message size: %d", size)
@ -49,7 +49,7 @@ func (h *Handler) HandleConn(conn net.Conn) error {
if len(messageBuf) < 8 { if len(messageBuf) < 8 {
return fmt.Errorf("message too short") return fmt.Errorf("message too short")
} }
apiKey := binary.BigEndian.Uint16(messageBuf[0:2]) apiKey := binary.BigEndian.Uint16(messageBuf[0:2])
apiVersion := binary.BigEndian.Uint16(messageBuf[2:4]) apiVersion := binary.BigEndian.Uint16(messageBuf[2:4])
correlationID := binary.BigEndian.Uint32(messageBuf[4:8]) correlationID := binary.BigEndian.Uint32(messageBuf[4:8])
@ -57,14 +57,14 @@ func (h *Handler) HandleConn(conn net.Conn) error {
// Handle the request based on API key // Handle the request based on API key
var response []byte var response []byte
var err error var err error
switch apiKey { switch apiKey {
case 18: // ApiVersions case 18: // ApiVersions
response, err = h.handleApiVersions(correlationID) response, err = h.handleApiVersions(correlationID)
default: default:
err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion) err = fmt.Errorf("unsupported API key: %d (version %d)", apiKey, apiVersion)
} }
if err != nil { if err != nil {
return fmt.Errorf("handle request: %w", err) return fmt.Errorf("handle request: %w", err)
} }
@ -72,14 +72,14 @@ func (h *Handler) HandleConn(conn net.Conn) error {
// Write response size and data // Write response size and data
responseSizeBytes := make([]byte, 4) responseSizeBytes := make([]byte, 4)
binary.BigEndian.PutUint32(responseSizeBytes, uint32(len(response))) binary.BigEndian.PutUint32(responseSizeBytes, uint32(len(response)))
if _, err := w.Write(responseSizeBytes); err != nil { if _, err := w.Write(responseSizeBytes); err != nil {
return fmt.Errorf("write response size: %w", err) return fmt.Errorf("write response size: %w", err)
} }
if _, err := w.Write(response); err != nil { if _, err := w.Write(response); err != nil {
return fmt.Errorf("write response: %w", err) return fmt.Errorf("write response: %w", err)
} }
if err := w.Flush(); err != nil { if err := w.Flush(); err != nil {
return fmt.Errorf("flush response: %w", err) return fmt.Errorf("flush response: %w", err)
} }
@ -89,27 +89,27 @@ func (h *Handler) HandleConn(conn net.Conn) error {
func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) { func (h *Handler) handleApiVersions(correlationID uint32) ([]byte, error) {
// Build ApiVersions response manually // Build ApiVersions response manually
// Response format: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4) // Response format: correlation_id(4) + error_code(2) + num_api_keys(4) + api_keys + throttle_time(4)
response := make([]byte, 0, 64) response := make([]byte, 0, 64)
// Correlation ID // Correlation ID
correlationIDBytes := make([]byte, 4) correlationIDBytes := make([]byte, 4)
binary.BigEndian.PutUint32(correlationIDBytes, correlationID) binary.BigEndian.PutUint32(correlationIDBytes, correlationID)
response = append(response, correlationIDBytes...) response = append(response, correlationIDBytes...)
// Error code (0 = no error) // Error code (0 = no error)
response = append(response, 0, 0) response = append(response, 0, 0)
// Number of API keys (compact array format in newer versions, but using basic format for simplicity) // Number of API keys (compact array format in newer versions, but using basic format for simplicity)
response = append(response, 0, 0, 0, 1) // 1 API key response = append(response, 0, 0, 0, 1) // 1 API key
// API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2) // API Key 18 (ApiVersions): api_key(2) + min_version(2) + max_version(2)
response = append(response, 0, 18) // API key 18 response = append(response, 0, 18) // API key 18
response = append(response, 0, 0) // min version 0 response = append(response, 0, 0) // min version 0
response = append(response, 0, 3) // max version 3 response = append(response, 0, 3) // max version 3
// Throttle time (4 bytes, 0 = no throttling) // Throttle time (4 bytes, 0 = no throttling)
response = append(response, 0, 0, 0, 0) response = append(response, 0, 0, 0, 0)
return response, nil return response, nil
} }

36
weed/mq/kafka/protocol/handler_test.go

@ -26,23 +26,23 @@ func TestHandler_ApiVersions(t *testing.T) {
// Request format: api_key(2) + api_version(2) + correlation_id(4) + client_id_size(2) + client_id + body // Request format: api_key(2) + api_version(2) + correlation_id(4) + client_id_size(2) + client_id + body
correlationID := uint32(12345) correlationID := uint32(12345)
clientID := "test-client" clientID := "test-client"
message := make([]byte, 0, 64) message := make([]byte, 0, 64)
message = append(message, 0, 18) // API key 18 (ApiVersions)
message = append(message, 0, 0) // API version 0
message = append(message, 0, 18) // API key 18 (ApiVersions)
message = append(message, 0, 0) // API version 0
// Correlation ID // Correlation ID
correlationIDBytes := make([]byte, 4) correlationIDBytes := make([]byte, 4)
binary.BigEndian.PutUint32(correlationIDBytes, correlationID) binary.BigEndian.PutUint32(correlationIDBytes, correlationID)
message = append(message, correlationIDBytes...) message = append(message, correlationIDBytes...)
// Client ID length and string // Client ID length and string
clientIDLen := uint16(len(clientID)) clientIDLen := uint16(len(clientID))
message = append(message, byte(clientIDLen>>8), byte(clientIDLen)) message = append(message, byte(clientIDLen>>8), byte(clientIDLen))
message = append(message, []byte(clientID)...) message = append(message, []byte(clientID)...)
// Empty request body for ApiVersions // Empty request body for ApiVersions
// Write message size and data // Write message size and data
messageSize := uint32(len(message)) messageSize := uint32(len(message))
sizeBuf := make([]byte, 4) sizeBuf := make([]byte, 4)
@ -77,34 +77,34 @@ func TestHandler_ApiVersions(t *testing.T) {
if len(respBuf) < 14 { // minimum response size if len(respBuf) < 14 { // minimum response size
t.Fatalf("response too short: %d bytes", len(respBuf)) t.Fatalf("response too short: %d bytes", len(respBuf))
} }
// Check correlation ID // Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(respBuf[0:4]) respCorrelationID := binary.BigEndian.Uint32(respBuf[0:4])
if respCorrelationID != correlationID { if respCorrelationID != correlationID {
t.Errorf("correlation ID mismatch: got %d, want %d", respCorrelationID, correlationID) t.Errorf("correlation ID mismatch: got %d, want %d", respCorrelationID, correlationID)
} }
// Check error code // Check error code
errorCode := binary.BigEndian.Uint16(respBuf[4:6]) errorCode := binary.BigEndian.Uint16(respBuf[4:6])
if errorCode != 0 { if errorCode != 0 {
t.Errorf("expected no error, got error code: %d", errorCode) t.Errorf("expected no error, got error code: %d", errorCode)
} }
// Check number of API keys // Check number of API keys
numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10]) numAPIKeys := binary.BigEndian.Uint32(respBuf[6:10])
if numAPIKeys != 1 { if numAPIKeys != 1 {
t.Errorf("expected 1 API key, got: %d", numAPIKeys) t.Errorf("expected 1 API key, got: %d", numAPIKeys)
} }
// Check API key details: api_key(2) + min_version(2) + max_version(2) // Check API key details: api_key(2) + min_version(2) + max_version(2)
if len(respBuf) < 16 { if len(respBuf) < 16 {
t.Fatalf("response too short for API key data") t.Fatalf("response too short for API key data")
} }
apiKey := binary.BigEndian.Uint16(respBuf[10:12]) apiKey := binary.BigEndian.Uint16(respBuf[10:12])
minVersion := binary.BigEndian.Uint16(respBuf[12:14]) minVersion := binary.BigEndian.Uint16(respBuf[12:14])
maxVersion := binary.BigEndian.Uint16(respBuf[14:16]) maxVersion := binary.BigEndian.Uint16(respBuf[14:16])
if apiKey != 18 { if apiKey != 18 {
t.Errorf("expected API key 18, got: %d", apiKey) t.Errorf("expected API key 18, got: %d", apiKey)
} }
@ -132,28 +132,28 @@ func TestHandler_ApiVersions(t *testing.T) {
func TestHandler_handleApiVersions(t *testing.T) { func TestHandler_handleApiVersions(t *testing.T) {
h := NewHandler() h := NewHandler()
correlationID := uint32(999) correlationID := uint32(999)
response, err := h.handleApiVersions(correlationID) response, err := h.handleApiVersions(correlationID)
if err != nil { if err != nil {
t.Fatalf("handleApiVersions: %v", err) t.Fatalf("handleApiVersions: %v", err)
} }
if len(response) < 20 { // minimum expected size if len(response) < 20 { // minimum expected size
t.Fatalf("response too short: %d bytes", len(response)) t.Fatalf("response too short: %d bytes", len(response))
} }
// Check correlation ID // Check correlation ID
respCorrelationID := binary.BigEndian.Uint32(response[0:4]) respCorrelationID := binary.BigEndian.Uint32(response[0:4])
if respCorrelationID != correlationID { if respCorrelationID != correlationID {
t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID) t.Errorf("correlation ID: got %d, want %d", respCorrelationID, correlationID)
} }
// Check error code // Check error code
errorCode := binary.BigEndian.Uint16(response[4:6]) errorCode := binary.BigEndian.Uint16(response[4:6])
if errorCode != 0 { if errorCode != 0 {
t.Errorf("error code: got %d, want 0", errorCode) t.Errorf("error code: got %d, want 0", errorCode)
} }
// Check API key // Check API key
apiKey := binary.BigEndian.Uint16(response[10:12]) apiKey := binary.BigEndian.Uint16(response[10:12])
if apiKey != 18 { if apiKey != 18 {

Loading…
Cancel
Save