You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
316 lines
7.6 KiB
316 lines
7.6 KiB
//go:build integration
|
|
|
|
package test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// Node represents an SSH-accessible (or local WSL2) machine.
|
|
type Node struct {
|
|
Host string
|
|
User string
|
|
KeyFile string
|
|
IsLocal bool // WSL2 mode: use exec.CommandContext instead of SSH
|
|
|
|
mu sync.Mutex
|
|
client *ssh.Client
|
|
}
|
|
|
|
// Connect establishes the SSH connection (no-op for local mode).
|
|
func (n *Node) Connect() error {
|
|
if n.IsLocal {
|
|
return nil
|
|
}
|
|
|
|
key, err := os.ReadFile(n.KeyFile)
|
|
if err != nil {
|
|
return fmt.Errorf("read SSH key %s: %w", n.KeyFile, err)
|
|
}
|
|
signer, err := ssh.ParsePrivateKey(key)
|
|
if err != nil {
|
|
return fmt.Errorf("parse SSH key: %w", err)
|
|
}
|
|
|
|
config := &ssh.ClientConfig{
|
|
User: n.User,
|
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
addr := n.Host
|
|
if !strings.Contains(addr, ":") {
|
|
addr += ":22"
|
|
}
|
|
|
|
n.mu.Lock()
|
|
defer n.mu.Unlock()
|
|
n.client, err = ssh.Dial("tcp", addr, config)
|
|
if err != nil {
|
|
return fmt.Errorf("SSH dial %s: %w", addr, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Run executes a command and returns stdout, stderr, exit code.
|
|
// The context controls timeout -- cancelled context kills the command.
|
|
func (n *Node) Run(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) {
|
|
if n.IsLocal {
|
|
return n.runLocal(ctx, cmd)
|
|
}
|
|
return n.runSSH(ctx, cmd)
|
|
}
|
|
|
|
func (n *Node) runLocal(ctx context.Context, cmd string) (string, string, int, error) {
|
|
c := exec.CommandContext(ctx, "wsl", "-e", "bash", "-c", cmd)
|
|
var outBuf, errBuf bytes.Buffer
|
|
c.Stdout = &outBuf
|
|
c.Stderr = &errBuf
|
|
|
|
err := c.Run()
|
|
if ctx.Err() != nil {
|
|
return outBuf.String(), errBuf.String(), -1, fmt.Errorf("command timed out: %w", ctx.Err())
|
|
}
|
|
if err != nil {
|
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
return outBuf.String(), errBuf.String(), exitErr.ExitCode(), nil
|
|
}
|
|
return outBuf.String(), errBuf.String(), -1, err
|
|
}
|
|
return outBuf.String(), errBuf.String(), 0, nil
|
|
}
|
|
|
|
func (n *Node) runSSH(ctx context.Context, cmd string) (string, string, int, error) {
|
|
n.mu.Lock()
|
|
if n.client == nil {
|
|
n.mu.Unlock()
|
|
return "", "", -1, fmt.Errorf("SSH not connected")
|
|
}
|
|
session, err := n.client.NewSession()
|
|
n.mu.Unlock()
|
|
if err != nil {
|
|
return "", "", -1, fmt.Errorf("new SSH session: %w", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
var outBuf, errBuf bytes.Buffer
|
|
session.Stdout = &outBuf
|
|
session.Stderr = &errBuf
|
|
|
|
done := make(chan error, 1)
|
|
go func() { done <- session.Run(cmd) }()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGKILL)
|
|
return outBuf.String(), errBuf.String(), -1, fmt.Errorf("command timed out: %w", ctx.Err())
|
|
case err := <-done:
|
|
if err != nil {
|
|
if exitErr, ok := err.(*ssh.ExitError); ok {
|
|
return outBuf.String(), errBuf.String(), exitErr.ExitStatus(), nil
|
|
}
|
|
return outBuf.String(), errBuf.String(), -1, err
|
|
}
|
|
return outBuf.String(), errBuf.String(), 0, nil
|
|
}
|
|
}
|
|
|
|
// RunRoot executes a command with sudo -n (non-interactive).
|
|
// Fails immediately if sudo requires a password instead of hanging.
|
|
func (n *Node) RunRoot(ctx context.Context, cmd string) (string, string, int, error) {
|
|
return n.Run(ctx, "sudo -n "+cmd)
|
|
}
|
|
|
|
// Upload copies a local file to the remote node via SCP.
|
|
func (n *Node) Upload(local, remote string) error {
|
|
if n.IsLocal {
|
|
// Convert Windows path to WSL path for cp
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
wslLocal := toWSLPath(local)
|
|
_, stderr, code, err := n.Run(ctx, fmt.Sprintf("cp %s %s && chmod +x %s", wslLocal, remote, remote))
|
|
if err != nil || code != 0 {
|
|
return fmt.Errorf("local upload: code=%d stderr=%s err=%v", code, stderr, err)
|
|
}
|
|
return nil
|
|
}
|
|
return n.scpUpload(local, remote)
|
|
}
|
|
|
|
func (n *Node) scpUpload(local, remote string) error {
|
|
data, err := os.ReadFile(local)
|
|
if err != nil {
|
|
return fmt.Errorf("read local file %s: %w", local, err)
|
|
}
|
|
|
|
n.mu.Lock()
|
|
if n.client == nil {
|
|
n.mu.Unlock()
|
|
return fmt.Errorf("SSH not connected")
|
|
}
|
|
session, err := n.client.NewSession()
|
|
n.mu.Unlock()
|
|
if err != nil {
|
|
return fmt.Errorf("new SSH session: %w", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
go func() {
|
|
w, _ := session.StdinPipe()
|
|
fmt.Fprintf(w, "C0755 %d %s\n", len(data), remoteName(remote))
|
|
w.Write(data)
|
|
fmt.Fprint(w, "\x00")
|
|
w.Close()
|
|
}()
|
|
|
|
dir := remoteDir(remote)
|
|
return session.Run(fmt.Sprintf("scp -t %s", dir))
|
|
}
|
|
|
|
// Download copies a remote file to local via SCP.
|
|
func (n *Node) Download(remote, local string) error {
|
|
if n.IsLocal {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
wslLocal := toWSLPath(local)
|
|
_, stderr, code, err := n.Run(ctx, fmt.Sprintf("cp %s %s", remote, wslLocal))
|
|
if err != nil || code != 0 {
|
|
return fmt.Errorf("local download: code=%d stderr=%s err=%v", code, stderr, err)
|
|
}
|
|
return nil
|
|
}
|
|
return n.scpDownload(remote, local)
|
|
}
|
|
|
|
func (n *Node) scpDownload(remote, local string) error {
|
|
n.mu.Lock()
|
|
if n.client == nil {
|
|
n.mu.Unlock()
|
|
return fmt.Errorf("SSH not connected")
|
|
}
|
|
session, err := n.client.NewSession()
|
|
n.mu.Unlock()
|
|
if err != nil {
|
|
return fmt.Errorf("new SSH session: %w", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
var buf bytes.Buffer
|
|
session.Stdout = &buf
|
|
if err := session.Run(fmt.Sprintf("cat %s", remote)); err != nil {
|
|
return fmt.Errorf("read remote %s: %w", remote, err)
|
|
}
|
|
return os.WriteFile(local, buf.Bytes(), 0644)
|
|
}
|
|
|
|
// Kill sends SIGKILL to a process by PID.
|
|
func (n *Node) Kill(pid int) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
_, _, _, err := n.RunRoot(ctx, fmt.Sprintf("kill -9 %d", pid))
|
|
return err
|
|
}
|
|
|
|
// HasCommand checks if a command is available on the node.
|
|
func (n *Node) HasCommand(cmd string) bool {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, _, code, err := n.Run(ctx, fmt.Sprintf("which %s", cmd))
|
|
return err == nil && code == 0
|
|
}
|
|
|
|
// Close closes the SSH connection.
|
|
func (n *Node) Close() {
|
|
n.mu.Lock()
|
|
defer n.mu.Unlock()
|
|
if n.client != nil {
|
|
n.client.Close()
|
|
n.client = nil
|
|
}
|
|
}
|
|
|
|
// DialTCP opens a direct TCP connection through the SSH tunnel.
|
|
func (n *Node) DialTCP(addr string) (net.Conn, error) {
|
|
if n.IsLocal {
|
|
return net.DialTimeout("tcp", addr, 5*time.Second)
|
|
}
|
|
n.mu.Lock()
|
|
defer n.mu.Unlock()
|
|
if n.client == nil {
|
|
return nil, fmt.Errorf("SSH not connected")
|
|
}
|
|
return n.client.Dial("tcp", addr)
|
|
}
|
|
|
|
// StreamRun executes a command and streams stdout to the writer.
|
|
func (n *Node) StreamRun(ctx context.Context, cmd string, w io.Writer) error {
|
|
if n.IsLocal {
|
|
c := exec.CommandContext(ctx, "wsl", "-e", "bash", "-c", cmd)
|
|
c.Stdout = w
|
|
c.Stderr = w
|
|
return c.Run()
|
|
}
|
|
|
|
n.mu.Lock()
|
|
if n.client == nil {
|
|
n.mu.Unlock()
|
|
return fmt.Errorf("SSH not connected")
|
|
}
|
|
session, err := n.client.NewSession()
|
|
n.mu.Unlock()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
session.Stdout = w
|
|
session.Stderr = w
|
|
|
|
done := make(chan error, 1)
|
|
go func() { done <- session.Run(cmd) }()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGKILL)
|
|
return ctx.Err()
|
|
case err := <-done:
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// toWSLPath rewrites a Windows path for use inside WSL: backslashes become
// forward slashes and a leading drive specifier "X:" becomes "/mnt/x" (drive
// letter lower-cased), e.g. C:\foo\bar -> /mnt/c/foo/bar. Paths without a
// drive specifier only get their slashes converted.
func toWSLPath(winPath string) string {
	unixy := strings.ReplaceAll(winPath, "\\", "/")
	if len(unixy) < 2 || unixy[1] != ':' {
		return unixy
	}
	driveLetter := strings.ToLower(string(unixy[0]))
	return "/mnt/" + driveLetter + unixy[2:]
}
|
|
|
|
// remoteName returns the part of path after its last "/", or the whole path
// when it contains no slash. A path ending in "/" yields the empty string.
func remoteName(path string) string {
	// LastIndex is -1 when there is no slash, which makes the slice start at 0.
	return path[strings.LastIndex(path, "/")+1:]
}
|
|
|
|
// remoteDir returns the directory portion of a slash-separated path:
// everything before the last "/". A path without any slash yields ".", and a
// path whose only slash is the leading one (e.g. "/file") yields "/" rather
// than an empty string, so the result is always usable as a directory
// argument.
func remoteDir(path string) string {
	idx := strings.LastIndex(path, "/")
	switch {
	case idx < 0:
		return "."
	case idx == 0:
		// File directly under the root: the directory is the root itself.
		return "/"
	}
	return path[:idx]
}
|