You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
9.4 KiB
331 lines
9.4 KiB
package ipc
|
|
|
|
import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vmihailenco/msgpack/v5"
)
|
|
|
|
// Client provides IPC communication with the Rust RDMA engine
|
|
type Client struct {
|
|
socketPath string
|
|
conn net.Conn
|
|
mu sync.RWMutex
|
|
logger *logrus.Logger
|
|
connected bool
|
|
}
|
|
|
|
// NewClient creates a new IPC client
|
|
func NewClient(socketPath string, logger *logrus.Logger) *Client {
|
|
if logger == nil {
|
|
logger = logrus.New()
|
|
logger.SetLevel(logrus.InfoLevel)
|
|
}
|
|
|
|
return &Client{
|
|
socketPath: socketPath,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Connect establishes connection to the Rust RDMA engine
|
|
func (c *Client) Connect(ctx context.Context) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.connected {
|
|
return nil
|
|
}
|
|
|
|
c.logger.WithField("socket", c.socketPath).Info("🔗 Connecting to Rust RDMA engine")
|
|
|
|
dialer := &net.Dialer{}
|
|
conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
|
|
if err != nil {
|
|
c.logger.WithError(err).Error("❌ Failed to connect to RDMA engine")
|
|
return fmt.Errorf("failed to connect to RDMA engine at %s: %w", c.socketPath, err)
|
|
}
|
|
|
|
c.conn = conn
|
|
c.connected = true
|
|
c.logger.Info("✅ Connected to Rust RDMA engine")
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes the connection
|
|
func (c *Client) Disconnect() {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
c.connected = false
|
|
c.logger.Info("🔌 Disconnected from Rust RDMA engine")
|
|
}
|
|
}
|
|
|
|
// IsConnected returns connection status
|
|
func (c *Client) IsConnected() bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return c.connected
|
|
}
|
|
|
|
// SendMessage sends an IPC message and waits for response
|
|
func (c *Client) SendMessage(ctx context.Context, msg *IpcMessage) (*IpcMessage, error) {
|
|
c.mu.RLock()
|
|
conn := c.conn
|
|
connected := c.connected
|
|
c.mu.RUnlock()
|
|
|
|
if !connected || conn == nil {
|
|
return nil, fmt.Errorf("not connected to RDMA engine")
|
|
}
|
|
|
|
// Set write timeout
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
conn.SetWriteDeadline(deadline)
|
|
} else {
|
|
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
|
}
|
|
|
|
c.logger.WithField("type", msg.Type).Debug("📤 Sending message to Rust engine")
|
|
|
|
// Serialize message with MessagePack
|
|
data, err := msgpack.Marshal(msg)
|
|
if err != nil {
|
|
c.logger.WithError(err).Error("❌ Failed to marshal message")
|
|
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
|
}
|
|
|
|
// Send message length (4 bytes) + message data
|
|
lengthBytes := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(lengthBytes, uint32(len(data)))
|
|
|
|
if _, err := conn.Write(lengthBytes); err != nil {
|
|
c.logger.WithError(err).Error("❌ Failed to send message length")
|
|
return nil, fmt.Errorf("failed to send message length: %w", err)
|
|
}
|
|
|
|
if _, err := conn.Write(data); err != nil {
|
|
c.logger.WithError(err).Error("❌ Failed to send message data")
|
|
return nil, fmt.Errorf("failed to send message data: %w", err)
|
|
}
|
|
|
|
c.logger.WithFields(logrus.Fields{
|
|
"type": msg.Type,
|
|
"size": len(data),
|
|
}).Debug("📤 Message sent successfully")
|
|
|
|
// Read response
|
|
return c.readResponse(ctx, conn)
|
|
}
|
|
|
|
// readResponse reads one length-prefixed response frame from conn and
// decodes it from MessagePack.
//
// The frame layout mirrors what SendMessage writes: a 4-byte little-endian
// payload length followed by that many payload bytes. Lengths above 64 MiB
// are rejected before allocating. The read is bounded by ctx's deadline, or
// by 30 seconds when ctx has none.
//
// io.ReadFull is required for both reads. A single conn.Read on a stream
// socket may legally return fewer bytes than requested, which would truncate
// the length prefix or the payload and desynchronize the stream.
func (c *Client) readResponse(ctx context.Context, conn net.Conn) (*IpcMessage, error) {
	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(30 * time.Second)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		c.logger.WithError(err).Error("❌ Failed to set read deadline")
		return nil, fmt.Errorf("failed to set read deadline: %w", err)
	}

	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		c.logger.WithError(err).Error("❌ Failed to read response length")
		return nil, fmt.Errorf("failed to read response length: %w", err)
	}

	// Sanity cap so a corrupt or hostile length prefix cannot force a huge
	// allocation.
	const maxResponseSize = 64 * 1024 * 1024

	length := binary.LittleEndian.Uint32(header[:])
	if length > maxResponseSize {
		c.logger.WithField("length", length).Error("❌ Response message too large")
		return nil, fmt.Errorf("response message too large: %d bytes", length)
	}

	data := make([]byte, length)
	if _, err := io.ReadFull(conn, data); err != nil {
		c.logger.WithError(err).Error("❌ Failed to read response data")
		return nil, fmt.Errorf("failed to read response data: %w", err)
	}

	c.logger.WithField("size", length).Debug("📥 Response received")

	var response IpcMessage
	if err := msgpack.Unmarshal(data, &response); err != nil {
		c.logger.WithError(err).Error("❌ Failed to unmarshal response")
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	c.logger.WithField("type", response.Type).Debug("📥 Response deserialized successfully")

	return &response, nil
}
|
|
|
|
// High-level convenience methods
|
|
|
|
// Ping sends a ping to the Rust engine to check that it is reachable and
// responding.
//
// clientID is passed through to NewPingMessage unchanged. If the engine
// replies with MsgError, the payload is decoded into an ErrorResponse and
// returned as an error. Any reply type other than MsgPong is also an error.
func (c *Client) Ping(ctx context.Context, clientID *string) (*PongResponse, error) {
	reply, err := c.SendMessage(ctx, NewPingMessage(clientID))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgPong:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// The generic Data payload is round-tripped through MessagePack to
	// decode it into the concrete response struct.
	raw, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal pong data: %w", err)
	}
	var pong PongResponse
	if err := msgpack.Unmarshal(raw, &pong); err != nil {
		return nil, fmt.Errorf("failed to unmarshal pong response: %w", err)
	}
	return &pong, nil
}
|
|
|
|
// GetCapabilities asks the Rust engine to describe its capabilities.
//
// clientID is passed through to NewGetCapabilitiesMessage unchanged. If the
// engine replies with MsgError, the payload is decoded into an ErrorResponse
// and returned as an error. Any reply type other than
// MsgGetCapabilitiesResponse is also an error.
func (c *Client) GetCapabilities(ctx context.Context, clientID *string) (*GetCapabilitiesResponse, error) {
	reply, err := c.SendMessage(ctx, NewGetCapabilitiesMessage(clientID))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgGetCapabilitiesResponse:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// The generic Data payload is round-tripped through MessagePack to
	// decode it into the concrete response struct.
	raw, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal capabilities data: %w", err)
	}
	var caps GetCapabilitiesResponse
	if err := msgpack.Unmarshal(raw, &caps); err != nil {
		return nil, fmt.Errorf("failed to unmarshal capabilities response: %w", err)
	}
	return &caps, nil
}
|
|
|
|
// StartRead asks the Rust engine to begin an RDMA read described by req.
//
// req is wrapped by NewStartReadMessage and sent via SendMessage. If the
// engine replies with MsgError, the payload is decoded into an ErrorResponse
// and returned as an error. Any reply type other than MsgStartReadResponse
// is also an error.
func (c *Client) StartRead(ctx context.Context, req *StartReadRequest) (*StartReadResponse, error) {
	reply, err := c.SendMessage(ctx, NewStartReadMessage(req))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgStartReadResponse:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// The generic Data payload is round-tripped through MessagePack to
	// decode it into the concrete response struct.
	raw, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal start read data: %w", err)
	}
	var started StartReadResponse
	if err := msgpack.Unmarshal(raw, &started); err != nil {
		return nil, fmt.Errorf("failed to unmarshal start read response: %w", err)
	}
	return &started, nil
}
|
|
|
|
// CompleteRead tells the Rust engine that the RDMA read for sessionID has
// finished. It reports whether the read succeeded, how many bytes were
// transferred, and an optional client-side CRC.
//
// The message is built with NewCompleteReadMessage. NOTE(review): its final
// argument is always passed as nil here; confirm what that parameter controls.
// If the engine replies with MsgError, the payload is decoded into an
// ErrorResponse and returned as an error. Any reply type other than
// MsgCompleteReadResponse is also an error.
func (c *Client) CompleteRead(ctx context.Context, sessionID string, success bool, bytesTransferred uint64, clientCrc *uint32) (*CompleteReadResponse, error) {
	reply, err := c.SendMessage(ctx, NewCompleteReadMessage(sessionID, success, bytesTransferred, clientCrc, nil))
	if err != nil {
		return nil, err
	}

	switch reply.Type {
	case MsgError:
		raw, encErr := msgpack.Marshal(reply.Data)
		if encErr != nil {
			return nil, fmt.Errorf("failed to marshal engine error data: %w", encErr)
		}
		var engineErr ErrorResponse
		if decErr := msgpack.Unmarshal(raw, &engineErr); decErr != nil {
			return nil, fmt.Errorf("failed to unmarshal engine error response: %w", decErr)
		}
		return nil, fmt.Errorf("engine error: %s - %s", engineErr.Code, engineErr.Message)
	case MsgCompleteReadResponse:
		// Expected reply; decoded below.
	default:
		return nil, fmt.Errorf("unexpected response type: %s", reply.Type)
	}

	// The generic Data payload is round-tripped through MessagePack to
	// decode it into the concrete response struct.
	raw, err := msgpack.Marshal(reply.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal complete read data: %w", err)
	}
	var completed CompleteReadResponse
	if err := msgpack.Unmarshal(raw, &completed); err != nil {
		return nil, fmt.Errorf("failed to unmarshal complete read response: %w", err)
	}
	return &completed, nil
}
|