You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

331 lines
9.4 KiB

package ipc
import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vmihailenco/msgpack/v5"
)
// Client provides IPC communication with the Rust RDMA engine over a
// Unix-domain stream socket.
//
// socketPath and logger are fixed by NewClient; conn and connected change
// with Connect/Disconnect and are read and written under mu.
type Client struct {
	socketPath string         // filesystem path of the engine's Unix socket
	conn       net.Conn       // active connection; nil until Connect succeeds or after Disconnect
	mu         sync.RWMutex   // guards conn and connected
	logger     *logrus.Logger // sink for connection and message logs
	connected  bool           // set by a successful Connect, cleared by Disconnect
}
// NewClient returns a disconnected Client for the engine socket at
// socketPath; call Connect before sending messages. When logger is nil a
// fresh logrus logger at Info level is used instead.
func NewClient(socketPath string, logger *logrus.Logger) *Client {
	c := &Client{socketPath: socketPath, logger: logger}
	if c.logger == nil {
		c.logger = logrus.New()
		c.logger.SetLevel(logrus.InfoLevel)
	}
	return c
}
// Connect dials the engine's Unix socket at c.socketPath. ctx bounds the dial.
// When the client is already connected it returns nil without dialing again.
// A dial failure is logged and returned wrapped with the socket path.
func (c *Client) Connect(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	c.logger.WithField("socket", c.socketPath).Info("🔗 Connecting to Rust RDMA engine")

	var d net.Dialer
	conn, err := d.DialContext(ctx, "unix", c.socketPath)
	if err != nil {
		c.logger.WithError(err).Error("❌ Failed to connect to RDMA engine")
		return fmt.Errorf("failed to connect to RDMA engine at %s: %w", c.socketPath, err)
	}

	c.conn, c.connected = conn, true
	c.logger.Info("✅ Connected to Rust RDMA engine")
	return nil
}
// Disconnect closes the active connection, if there is one, and marks the
// client as disconnected. It does nothing when no connection is held.
func (c *Client) Disconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return
	}

	_ = c.conn.Close() // close error deliberately ignored; the client is being torn down
	c.conn, c.connected = nil, false
	c.logger.Info("🔌 Disconnected from Rust RDMA engine")
}
// IsConnected reports the connection flag maintained by Connect and
// Disconnect.
func (c *Client) IsConnected() bool {
	c.mu.RLock()
	connected := c.connected
	c.mu.RUnlock()
	return connected
}
// SendMessage sends msg to the engine and waits for its reply.
//
// Each message travels as one frame: a 4-byte little-endian payload length
// followed by the MessagePack encoding of msg. The write deadline comes from
// ctx, or 30 seconds from now when ctx has none. readResponse applies the same
// rule when reading the reply.
//
// The header and payload are written with a single Write call, so frames from
// concurrent callers cannot interleave on the wire. If that write fails, part
// of a frame may already have reached the engine. The connection is therefore
// closed and the client marked disconnected; call Connect to dial again.
//
// NOTE(review): request/response pairs are still not serialized across
// goroutines, so two concurrent callers can read each other's replies.
// Confirm that callers serialize access, or add a per-exchange mutex to Client.
func (c *Client) SendMessage(ctx context.Context, msg *IpcMessage) (*IpcMessage, error) {
	// Largest payload the 4-byte length prefix can describe.
	const maxFrameSize = 1<<32 - 1

	c.mu.RLock()
	conn, connected := c.conn, c.connected
	c.mu.RUnlock()
	if !connected || conn == nil {
		return nil, fmt.Errorf("not connected to RDMA engine")
	}

	// Without a deadline a cancelled ctx would otherwise go unnoticed.
	if err := ctx.Err(); err != nil {
		return nil, fmt.Errorf("failed to send message: %w", err)
	}

	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(30 * time.Second)
	}
	if err := conn.SetWriteDeadline(deadline); err != nil {
		return nil, fmt.Errorf("failed to set write deadline: %w", err)
	}

	c.logger.WithField("type", msg.Type).Debug("📤 Sending message to Rust engine")

	// Serialize message with MessagePack.
	data, err := msgpack.Marshal(msg)
	if err != nil {
		c.logger.WithError(err).Error("❌ Failed to marshal message")
		return nil, fmt.Errorf("failed to marshal message: %w", err)
	}
	if uint64(len(data)) > maxFrameSize {
		// uint32(len(data)) would silently truncate the length prefix.
		return nil, fmt.Errorf("message too large: %d bytes", len(data))
	}

	// Build header+payload as one buffer so the frame goes out in one Write.
	frame := make([]byte, 4+len(data))
	binary.LittleEndian.PutUint32(frame, uint32(len(data)))
	copy(frame[4:], data)

	if n, err := conn.Write(frame); err != nil {
		// The engine may now be mid-frame. Drop the connection rather than
		// let the next frame be parsed as the rest of this one.
		c.mu.Lock()
		if c.conn == conn {
			c.conn, c.connected = nil, false
		}
		c.mu.Unlock()
		_ = conn.Close() // best effort: the stream is already unusable

		if n < 4 {
			c.logger.WithError(err).Error("❌ Failed to send message length")
			return nil, fmt.Errorf("failed to send message length: %w", err)
		}
		c.logger.WithError(err).Error("❌ Failed to send message data")
		return nil, fmt.Errorf("failed to send message data: %w", err)
	}

	c.logger.WithFields(logrus.Fields{
		"type": msg.Type,
		"size": len(data),
	}).Debug("📤 Message sent successfully")

	return c.readResponse(ctx, conn)
}
// readResponse reads one reply frame from conn and decodes it into an
// IpcMessage.
//
// A frame is a 4-byte little-endian payload length followed by that many
// bytes of MessagePack data. Lengths above 64 MiB are rejected. The read
// deadline comes from ctx, or 30 seconds from now when ctx has none.
//
// Both reads use io.ReadFull because a single Read on a stream socket may
// return fewer bytes than requested. That would truncate the header or
// payload, especially for large replies.
//
// A frame may not be consumed completely: the deadline cannot be set, a read
// fails or times out, or the length is too large. The stream is then no
// longer at a frame boundary, and a late reply would be handed to the next
// caller. In that case the connection is dropped via dropDesyncedConn. A
// decode failure after a complete frame leaves the connection usable.
func (c *Client) readResponse(ctx context.Context, conn net.Conn) (*IpcMessage, error) {
	const maxResponseSize = 64 * 1024 * 1024 // 64MB sanity check

	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(30 * time.Second)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		c.dropDesyncedConn(conn)
		return nil, fmt.Errorf("failed to set read deadline: %w", err)
	}

	// Read message length (exactly 4 bytes).
	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		c.logger.WithError(err).Error("❌ Failed to read response length")
		c.dropDesyncedConn(conn)
		return nil, fmt.Errorf("failed to read response length: %w", err)
	}
	length := binary.LittleEndian.Uint32(header[:])
	if length > maxResponseSize {
		c.logger.WithField("length", length).Error("❌ Response message too large")
		c.dropDesyncedConn(conn) // payload left unread on the socket
		return nil, fmt.Errorf("response message too large: %d bytes", length)
	}

	// Read exactly length bytes of payload.
	data := make([]byte, length)
	if _, err := io.ReadFull(conn, data); err != nil {
		c.logger.WithError(err).Error("❌ Failed to read response data")
		c.dropDesyncedConn(conn)
		return nil, fmt.Errorf("failed to read response data: %w", err)
	}
	c.logger.WithField("size", length).Debug("📥 Response received")

	// Deserialize with MessagePack; the frame was fully consumed, so the
	// connection stays in sync even if this fails.
	var response IpcMessage
	if err := msgpack.Unmarshal(data, &response); err != nil {
		c.logger.WithError(err).Error("❌ Failed to unmarshal response")
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	c.logger.WithField("type", response.Type).Debug("📥 Response deserialized successfully")
	return &response, nil
}

// dropDesyncedConn closes conn after an incomplete frame read. If conn is
// still the client's active connection, the client is also marked
// disconnected, so a later Connect dials a fresh socket instead of returning
// early.
func (c *Client) dropDesyncedConn(conn net.Conn) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.conn == conn {
		c.conn = nil
		c.connected = false
	}
	_ = conn.Close() // best effort: the stream is already unusable
	c.logger.Warn("🔌 Dropped RDMA engine connection after an incomplete response")
}
// High-level convenience methods

// Ping sends a ping built by NewPingMessage(clientID) and returns the engine's
// pong. Reply handling:
//   - MsgError: the payload is decoded as ErrorResponse and returned as
//     "engine error: <code> - <message>".
//   - MsgPong: decoded into PongResponse.
//   - Anything else: rejected as an unexpected response type.
func (c *Client) Ping(ctx context.Context, clientID *string) (*PongResponse, error) {
	reply, err := c.SendMessage(ctx, NewPingMessage(clientID))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgPong:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// Re-encode reply.Data and decode it into the typed struct (presumably Data
	// holds a generically decoded value — confirm against IpcMessage).
	encoded, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal pong data: %w", err)
	}
	pong := new(PongResponse)
	if err := msgpack.Unmarshal(encoded, pong); err != nil {
		return nil, fmt.Errorf("failed to unmarshal pong response: %w", err)
	}
	return pong, nil
}
// GetCapabilities sends the message built by NewGetCapabilitiesMessage(clientID)
// and decodes the engine's capability report. An MsgError reply is returned as
// "engine error: <code> - <message>"; any reply other than
// MsgGetCapabilitiesResponse is rejected as unexpected.
func (c *Client) GetCapabilities(ctx context.Context, clientID *string) (*GetCapabilitiesResponse, error) {
	reply, err := c.SendMessage(ctx, NewGetCapabilitiesMessage(clientID))
	if err != nil {
		return nil, err
	}

	if reply.Type != MsgGetCapabilitiesResponse {
		if reply.Type != MsgError {
			return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
		}
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	}

	// Round-trip reply.Data through MessagePack into the typed response.
	encoded, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal capabilities data: %w", err)
	}
	caps := new(GetCapabilitiesResponse)
	if err := msgpack.Unmarshal(encoded, caps); err != nil {
		return nil, fmt.Errorf("failed to unmarshal capabilities response: %w", err)
	}
	return caps, nil
}
// StartRead asks the engine to begin an RDMA read described by req, which is
// wrapped with NewStartReadMessage. An MsgError reply is returned as
// "engine error: <code> - <message>"; any reply other than
// MsgStartReadResponse is rejected as unexpected.
func (c *Client) StartRead(ctx context.Context, req *StartReadRequest) (*StartReadResponse, error) {
	reply, err := c.SendMessage(ctx, NewStartReadMessage(req))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgStartReadResponse:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// Round-trip reply.Data through MessagePack into the typed response.
	encoded, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal start read data: %w", err)
	}
	started := new(StartReadResponse)
	if err := msgpack.Unmarshal(encoded, started); err != nil {
		return nil, fmt.Errorf("failed to unmarshal start read response: %w", err)
	}
	return started, nil
}
// CompleteRead reports the outcome of the RDMA read identified by sessionID.
// success, bytesTransferred and clientCrc are forwarded to
// NewCompleteReadMessage, whose final argument is always nil here. clientCrc
// may be nil (presumably a client-computed CRC of the data — confirm).
// An MsgError reply is returned as "engine error: <code> - <message>"; any
// reply other than MsgCompleteReadResponse is rejected as unexpected.
func (c *Client) CompleteRead(ctx context.Context, sessionID string, success bool, bytesTransferred uint64, clientCrc *uint32) (*CompleteReadResponse, error) {
	request := NewCompleteReadMessage(sessionID, success, bytesTransferred, clientCrc, nil)
	reply, sendErr := c.SendMessage(ctx, request)
	if sendErr != nil {
		return nil, sendErr
	}

	isError := reply.Type == MsgError
	if isError {
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	}
	if reply.Type != MsgCompleteReadResponse {
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// Round-trip reply.Data through MessagePack into the typed response.
	encoded, encErr := msgpack.Marshal(reply.Data)
	if encErr != nil {
		return nil, fmt.Errorf("failed to marshal complete read data: %w", encErr)
	}
	done := new(CompleteReadResponse)
	if decErr := msgpack.Unmarshal(encoded, done); decErr != nil {
		return nil, fmt.Errorf("failed to unmarshal complete read response: %w", decErr)
	}
	return done, nil
}