You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							331 lines
						
					
					
						
							9.4 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							331 lines
						
					
					
						
							9.4 KiB
						
					
					
				
								package ipc
							 | 
						|
								
							 | 
						|
								import (
									"context"
									"encoding/binary"
									"fmt"
									"io"
									"net"
									"sync"
									"time"

									"github.com/sirupsen/logrus"
									"github.com/vmihailenco/msgpack/v5"
								)
							 | 
						|
								
							 | 
						|
								// Client provides IPC communication with the Rust RDMA engine
							 | 
						|
								type Client struct {
							 | 
						|
									socketPath string
							 | 
						|
									conn       net.Conn
							 | 
						|
									mu         sync.RWMutex
							 | 
						|
									logger     *logrus.Logger
							 | 
						|
									connected  bool
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// NewClient creates a new IPC client
							 | 
						|
								func NewClient(socketPath string, logger *logrus.Logger) *Client {
							 | 
						|
									if logger == nil {
							 | 
						|
										logger = logrus.New()
							 | 
						|
										logger.SetLevel(logrus.InfoLevel)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &Client{
							 | 
						|
										socketPath: socketPath,
							 | 
						|
										logger:     logger,
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Connect establishes connection to the Rust RDMA engine
							 | 
						|
								func (c *Client) Connect(ctx context.Context) error {
							 | 
						|
									c.mu.Lock()
							 | 
						|
									defer c.mu.Unlock()
							 | 
						|
								
							 | 
						|
									if c.connected {
							 | 
						|
										return nil
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.logger.WithField("socket", c.socketPath).Info("🔗 Connecting to Rust RDMA engine")
							 | 
						|
								
							 | 
						|
									dialer := &net.Dialer{}
							 | 
						|
									conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
							 | 
						|
									if err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to connect to RDMA engine")
							 | 
						|
										return fmt.Errorf("failed to connect to RDMA engine at %s: %w", c.socketPath, err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.conn = conn
							 | 
						|
									c.connected = true
							 | 
						|
									c.logger.Info("✅ Connected to Rust RDMA engine")
							 | 
						|
								
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Disconnect closes the connection
							 | 
						|
								func (c *Client) Disconnect() {
							 | 
						|
									c.mu.Lock()
							 | 
						|
									defer c.mu.Unlock()
							 | 
						|
								
							 | 
						|
									if c.conn != nil {
							 | 
						|
										c.conn.Close()
							 | 
						|
										c.conn = nil
							 | 
						|
										c.connected = false
							 | 
						|
										c.logger.Info("🔌 Disconnected from Rust RDMA engine")
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// IsConnected returns connection status
							 | 
						|
								func (c *Client) IsConnected() bool {
							 | 
						|
									c.mu.RLock()
							 | 
						|
									defer c.mu.RUnlock()
							 | 
						|
									return c.connected
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// SendMessage sends an IPC message and waits for response
							 | 
						|
								func (c *Client) SendMessage(ctx context.Context, msg *IpcMessage) (*IpcMessage, error) {
							 | 
						|
									c.mu.RLock()
							 | 
						|
									conn := c.conn
							 | 
						|
									connected := c.connected
							 | 
						|
									c.mu.RUnlock()
							 | 
						|
								
							 | 
						|
									if !connected || conn == nil {
							 | 
						|
										return nil, fmt.Errorf("not connected to RDMA engine")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Set write timeout
							 | 
						|
									if deadline, ok := ctx.Deadline(); ok {
							 | 
						|
										conn.SetWriteDeadline(deadline)
							 | 
						|
									} else {
							 | 
						|
										conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.logger.WithField("type", msg.Type).Debug("📤 Sending message to Rust engine")
							 | 
						|
								
							 | 
						|
									// Serialize message with MessagePack
							 | 
						|
									data, err := msgpack.Marshal(msg)
							 | 
						|
									if err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to marshal message")
							 | 
						|
										return nil, fmt.Errorf("failed to marshal message: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Send message length (4 bytes) + message data
							 | 
						|
									lengthBytes := make([]byte, 4)
							 | 
						|
									binary.LittleEndian.PutUint32(lengthBytes, uint32(len(data)))
							 | 
						|
								
							 | 
						|
									if _, err := conn.Write(lengthBytes); err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to send message length")
							 | 
						|
										return nil, fmt.Errorf("failed to send message length: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if _, err := conn.Write(data); err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to send message data")
							 | 
						|
										return nil, fmt.Errorf("failed to send message data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.logger.WithFields(logrus.Fields{
							 | 
						|
										"type": msg.Type,
							 | 
						|
										"size": len(data),
							 | 
						|
									}).Debug("📤 Message sent successfully")
							 | 
						|
								
							 | 
						|
									// Read response
							 | 
						|
									return c.readResponse(ctx, conn)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// readResponse reads and deserializes the response message
							 | 
						|
								func (c *Client) readResponse(ctx context.Context, conn net.Conn) (*IpcMessage, error) {
							 | 
						|
									// Set read timeout
							 | 
						|
									if deadline, ok := ctx.Deadline(); ok {
							 | 
						|
										conn.SetReadDeadline(deadline)
							 | 
						|
									} else {
							 | 
						|
										conn.SetReadDeadline(time.Now().Add(30 * time.Second))
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read message length (4 bytes)
							 | 
						|
									lengthBytes := make([]byte, 4)
							 | 
						|
									if _, err := conn.Read(lengthBytes); err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to read response length")
							 | 
						|
										return nil, fmt.Errorf("failed to read response length: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									length := binary.LittleEndian.Uint32(lengthBytes)
							 | 
						|
									if length > 64*1024*1024 { // 64MB sanity check
							 | 
						|
										c.logger.WithField("length", length).Error("❌ Response message too large")
							 | 
						|
										return nil, fmt.Errorf("response message too large: %d bytes", length)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Read message data
							 | 
						|
									data := make([]byte, length)
							 | 
						|
									if _, err := conn.Read(data); err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to read response data")
							 | 
						|
										return nil, fmt.Errorf("failed to read response data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.logger.WithField("size", length).Debug("📥 Response received")
							 | 
						|
								
							 | 
						|
									// Deserialize with MessagePack
							 | 
						|
									var response IpcMessage
							 | 
						|
									if err := msgpack.Unmarshal(data, &response); err != nil {
							 | 
						|
										c.logger.WithError(err).Error("❌ Failed to unmarshal response")
							 | 
						|
										return nil, fmt.Errorf("failed to unmarshal response: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									c.logger.WithField("type", response.Type).Debug("📥 Response deserialized successfully")
							 | 
						|
								
							 | 
						|
									return &response, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// High-level convenience methods
							 | 
						|
								
							 | 
						|
								// Ping sends a ping message to test connectivity
							 | 
						|
								func (c *Client) Ping(ctx context.Context, clientID *string) (*PongResponse, error) {
							 | 
						|
									msg := NewPingMessage(clientID)
							 | 
						|
								
							 | 
						|
									response, err := c.SendMessage(ctx, msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type == MsgError {
							 | 
						|
										errorData, err := msgpack.Marshal(response.Data)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
							 | 
						|
										}
							 | 
						|
										var errorResp ErrorResponse
							 | 
						|
										if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
							 | 
						|
										}
							 | 
						|
										return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type != MsgPong {
							 | 
						|
										return nil, fmt.Errorf("unexpected response type: %s", response.Type)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert response data to PongResponse
							 | 
						|
									pongData, err := msgpack.Marshal(response.Data)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to marshal pong data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									var pong PongResponse
							 | 
						|
									if err := msgpack.Unmarshal(pongData, &pong); err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to unmarshal pong response: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &pong, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GetCapabilities requests engine capabilities
							 | 
						|
								func (c *Client) GetCapabilities(ctx context.Context, clientID *string) (*GetCapabilitiesResponse, error) {
							 | 
						|
									msg := NewGetCapabilitiesMessage(clientID)
							 | 
						|
								
							 | 
						|
									response, err := c.SendMessage(ctx, msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type == MsgError {
							 | 
						|
										errorData, err := msgpack.Marshal(response.Data)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
							 | 
						|
										}
							 | 
						|
										var errorResp ErrorResponse
							 | 
						|
										if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
							 | 
						|
										}
							 | 
						|
										return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type != MsgGetCapabilitiesResponse {
							 | 
						|
										return nil, fmt.Errorf("unexpected response type: %s", response.Type)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert response data to GetCapabilitiesResponse
							 | 
						|
									capsData, err := msgpack.Marshal(response.Data)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to marshal capabilities data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									var caps GetCapabilitiesResponse
							 | 
						|
									if err := msgpack.Unmarshal(capsData, &caps); err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to unmarshal capabilities response: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &caps, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// StartRead initiates an RDMA read operation
							 | 
						|
								func (c *Client) StartRead(ctx context.Context, req *StartReadRequest) (*StartReadResponse, error) {
							 | 
						|
									msg := NewStartReadMessage(req)
							 | 
						|
								
							 | 
						|
									response, err := c.SendMessage(ctx, msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type == MsgError {
							 | 
						|
										errorData, err := msgpack.Marshal(response.Data)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
							 | 
						|
										}
							 | 
						|
										var errorResp ErrorResponse
							 | 
						|
										if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
							 | 
						|
										}
							 | 
						|
										return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type != MsgStartReadResponse {
							 | 
						|
										return nil, fmt.Errorf("unexpected response type: %s", response.Type)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert response data to StartReadResponse
							 | 
						|
									startData, err := msgpack.Marshal(response.Data)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to marshal start read data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									var startResp StartReadResponse
							 | 
						|
									if err := msgpack.Unmarshal(startData, &startResp); err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to unmarshal start read response: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &startResp, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// CompleteRead completes an RDMA read operation
							 | 
						|
								func (c *Client) CompleteRead(ctx context.Context, sessionID string, success bool, bytesTransferred uint64, clientCrc *uint32) (*CompleteReadResponse, error) {
							 | 
						|
									msg := NewCompleteReadMessage(sessionID, success, bytesTransferred, clientCrc, nil)
							 | 
						|
								
							 | 
						|
									response, err := c.SendMessage(ctx, msg)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type == MsgError {
							 | 
						|
										errorData, err := msgpack.Marshal(response.Data)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
							 | 
						|
										}
							 | 
						|
										var errorResp ErrorResponse
							 | 
						|
										if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
							 | 
						|
											return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
							 | 
						|
										}
							 | 
						|
										return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if response.Type != MsgCompleteReadResponse {
							 | 
						|
										return nil, fmt.Errorf("unexpected response type: %s", response.Type)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Convert response data to CompleteReadResponse
							 | 
						|
									completeData, err := msgpack.Marshal(response.Data)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to marshal complete read data: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									var completeResp CompleteReadResponse
							 | 
						|
									if err := msgpack.Unmarshal(completeData, &completeResp); err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to unmarshal complete read response: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &completeResp, nil
							 | 
						|
								}
							 |