Browse Source

feat: add BlockVol engine and iSCSI target (Phase 1 + Phase 2)

Phase 1: Extent-mapped block storage engine with WAL, crash recovery,
dirty map, flusher, and group commit. 174 tests, zero SeaweedFS imports.

Phase 2: Pure Go iSCSI target (RFC 7143) with PDU codec, login
negotiation, SendTargets discovery, 12 SCSI opcodes, Data-In/Out/R2T
sequencing, session management, and standalone iscsi-target binary.
164 tests. IQN->BlockDevice binding via DeviceLookup interface.

Total: 338 tests, 14.6K lines.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
feature/sw-block
Ping Qiu 1 week ago
parent
commit
6a400f6760
  1. 371
      weed/storage/blockvol/blockvol.go
  2. 3903
      weed/storage/blockvol/blockvol_qa_test.go
  3. 685
      weed/storage/blockvol/blockvol_test.go
  4. 110
      weed/storage/blockvol/dirty_map.go
  5. 153
      weed/storage/blockvol/dirty_map_test.go
  6. 239
      weed/storage/blockvol/flusher.go
  7. 261
      weed/storage/blockvol/flusher_test.go
  8. 167
      weed/storage/blockvol/group_commit.go
  9. 287
      weed/storage/blockvol/group_commit_test.go
  10. 141
      weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go
  11. 216
      weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh
  12. 207
      weed/storage/blockvol/iscsi/dataio.go
  13. 410
      weed/storage/blockvol/iscsi/dataio_test.go
  14. 68
      weed/storage/blockvol/iscsi/discovery.go
  15. 213
      weed/storage/blockvol/iscsi/discovery_test.go
  16. 412
      weed/storage/blockvol/iscsi/integration_test.go
  17. 396
      weed/storage/blockvol/iscsi/login.go
  18. 444
      weed/storage/blockvol/iscsi/login_test.go
  19. 183
      weed/storage/blockvol/iscsi/params.go
  20. 357
      weed/storage/blockvol/iscsi/params_test.go
  21. 447
      weed/storage/blockvol/iscsi/pdu.go
  22. 559
      weed/storage/blockvol/iscsi/pdu_test.go
  23. 463
      weed/storage/blockvol/iscsi/scsi.go
  24. 692
      weed/storage/blockvol/iscsi/scsi_test.go
  25. 421
      weed/storage/blockvol/iscsi/session.go
  26. 364
      weed/storage/blockvol/iscsi/session_test.go
  27. 216
      weed/storage/blockvol/iscsi/target.go
  28. 259
      weed/storage/blockvol/iscsi/target_test.go
  29. 38
      weed/storage/blockvol/lba.go
  30. 95
      weed/storage/blockvol/lba_test.go
  31. 157
      weed/storage/blockvol/recovery.go
  32. 416
      weed/storage/blockvol/recovery_test.go
  33. 249
      weed/storage/blockvol/superblock.go
  34. 146
      weed/storage/blockvol/superblock_test.go
  35. 153
      weed/storage/blockvol/wal_entry.go
  36. 284
      weed/storage/blockvol/wal_entry_test.go
  37. 192
      weed/storage/blockvol/wal_writer.go
  38. 244
      weed/storage/blockvol/wal_writer_test.go

371
weed/storage/blockvol/blockvol.go

@@ -0,0 +1,371 @@
// Package blockvol implements a block volume storage engine with WAL,
// dirty map, group commit, and crash recovery. It operates on a single
// file and has no dependency on SeaweedFS internals.
package blockvol
import (
"encoding/binary"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
)
// CreateOptions configures a new block volume.
// Zero-valued fields fall back to the defaults noted below; VolumeSize
// has no default and must be provided by the caller (CreateBlockVol
// rejects a zero VolumeSize with ErrInvalidVolumeSize).
type CreateOptions struct {
	VolumeSize  uint64 // required, logical size in bytes
	ExtentSize  uint32 // default 64KB
	BlockSize   uint32 // default 4KB
	WALSize     uint64 // default 64MB
	Replication string // default "000"
}
// BlockVol is the core block volume engine. All state lives in a single
// backing file laid out as: superblock, WAL region, extent (data) region.
type BlockVol struct {
	mu          sync.RWMutex    // NOTE(review): declared but unused by the methods visible here — confirm intended scope
	fd          *os.File        // backing file (superblock + WAL + extents)
	path        string          // file path, exposed via Path()
	super       Superblock      // in-memory copy of the on-disk superblock
	wal         *WALWriter      // appends entries into the WAL region
	dirtyMap    *DirtyMap       // LBA → WAL offset for unflushed writes
	groupCommit *GroupCommitter // batches fsyncs for SyncCache callers
	flusher     *Flusher        // background WAL→extent flusher
	nextLSN     atomic.Uint64   // next log sequence number to hand out
	healthy     atomic.Bool     // cleared by the group committer's OnDegraded callback
}
// CreateBlockVol creates a new block volume file at path.
//
// File layout: superblock, then the WAL region, then an extent (data)
// region of opts.VolumeSize bytes. The file is created with O_EXCL so an
// existing volume is never clobbered. On any failure after creation the
// partially written file is removed so a retry starts clean. The group
// committer and flusher goroutines are started before returning; callers
// must Close the returned volume to stop them.
func CreateBlockVol(path string, opts CreateOptions) (*BlockVol, error) {
	if opts.VolumeSize == 0 {
		return nil, ErrInvalidVolumeSize
	}
	sb, err := NewSuperblock(opts.VolumeSize, opts)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create superblock: %w", err)
	}
	sb.CreatedAt = uint64(time.Now().Unix())
	// Extent region starts after superblock + WAL.
	extentStart := sb.WALOffset + sb.WALSize
	totalFileSize := extentStart + opts.VolumeSize
	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create file: %w", err)
	}
	// fail unwinds a partially created volume: close the fd, remove the
	// file, and wrap the underlying error. Factored out so the
	// close/remove/wrap triplet is not repeated at every error site.
	fail := func(op string, err error) (*BlockVol, error) {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: %s: %w", op, err)
	}
	if err := fd.Truncate(int64(totalFileSize)); err != nil {
		return fail("truncate", err)
	}
	if _, err := sb.WriteTo(fd); err != nil {
		return fail("write superblock", err)
	}
	if err := fd.Sync(); err != nil {
		return fail("sync", err)
	}
	v := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      NewWALWriter(fd, sb.WALOffset, sb.WALSize, 0, 0), // fresh WAL: head=tail=0
		dirtyMap: NewDirtyMap(),
	}
	v.nextLSN.Store(1) // first LSN handed out will be 1
	v.healthy.Store(true)
	v.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { v.healthy.Store(false) },
	})
	go v.groupCommit.Run()
	v.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &v.super,
		WAL:      v.wal,
		DirtyMap: v.dirtyMap,
	})
	go v.flusher.Run()
	return v, nil
}
// OpenBlockVol opens an existing block volume file and runs crash recovery.
// Recovery replays WAL entries into the dirty map so unflushed writes are
// readable again; the next LSN is advanced past both the checkpoint and
// the highest replayed entry.
func OpenBlockVol(path string) (*BlockVol, error) {
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: open file: %w", err)
	}
	sb, err := ReadSuperblock(fd)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: read superblock: %w", err)
	}
	if err := sb.Validate(); err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: validate superblock: %w", err)
	}
	dm := NewDirtyMap()
	// Run WAL recovery: replay entries from tail to head.
	recovered, err := RecoverWAL(fd, &sb, dm)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: recovery: %w", err)
	}
	// nextLSN = max(checkpoint+1, highest replayed LSN + 1).
	nextLSN := recovered.HighestLSN + 1
	if fromCheckpoint := sb.WALCheckpointLSN + 1; fromCheckpoint > nextLSN {
		nextLSN = fromCheckpoint
	}
	vol := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      NewWALWriter(fd, sb.WALOffset, sb.WALSize, sb.WALHead, sb.WALTail),
		dirtyMap: dm,
	}
	vol.nextLSN.Store(nextLSN)
	vol.healthy.Store(true)
	vol.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { vol.healthy.Store(false) },
	})
	go vol.groupCommit.Run()
	vol.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &vol.super,
		WAL:      vol.wal,
		DirtyMap: dm,
	})
	go vol.flusher.Run()
	return vol, nil
}
// WriteLBA writes data at the given logical block address.
// Data length must be a multiple of BlockSize. The write is appended to
// the WAL and each covered block is marked dirty; durability requires a
// later SyncCache.
func (v *BlockVol) WriteLBA(lba uint64, data []byte) error {
	dataLen := uint32(len(data))
	if err := ValidateWrite(lba, dataLen, v.super.VolumeSize, v.super.BlockSize); err != nil {
		return err
	}
	lsn := v.nextLSN.Add(1) - 1
	walOff, err := v.wal.Append(&WALEntry{
		LSN:    lsn,
		Epoch:  0, // Phase 1: no fencing
		Type:   EntryTypeWrite,
		LBA:    lba,
		Length: dataLen,
		Data:   data,
	})
	if err != nil {
		return fmt.Errorf("blockvol: WriteLBA: %w", err)
	}
	// Mark every block of this write dirty. All blocks point at the same
	// WAL entry offset; the read path locates each block within the entry.
	for blk := uint32(0); blk < dataLen/v.super.BlockSize; blk++ {
		v.dirtyMap.Put(lba+uint64(blk), walOff, lsn, v.super.BlockSize)
	}
	return nil
}
// ReadLBA reads data at the given logical block address.
// length is in bytes and must be a multiple of BlockSize. Each block is
// served from the dirty map (WAL) when present, otherwise from the
// extent region.
func (v *BlockVol) ReadLBA(lba uint64, length uint32) ([]byte, error) {
	if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil {
		return nil, err
	}
	out := make([]byte, length)
	blockSize := v.super.BlockSize
	for off := uint32(0); off < length; off += blockSize {
		cur := lba + uint64(off/blockSize)
		blk, err := v.readOneBlock(cur)
		if err != nil {
			return nil, fmt.Errorf("blockvol: ReadLBA block %d: %w", cur, err)
		}
		copy(out[off:], blk)
	}
	return out, nil
}
// readOneBlock reads a single block: the dirty map wins (freshest data
// still in the WAL); otherwise the block comes from the extent region.
func (v *BlockVol) readOneBlock(lba uint64) ([]byte, error) {
	if walOff, _, _, dirty := v.dirtyMap.Get(lba); dirty {
		return v.readBlockFromWAL(walOff, lba)
	}
	return v.readBlockFromExtent(lba)
}
// maxWALEntryDataLen caps the data length we trust from a WAL entry header.
// Anything larger than the WAL region itself is corrupt; this absolute
// ceiling additionally bounds the allocation even for very large WAL
// configurations.
const maxWALEntryDataLen = 256 * 1024 * 1024 // 256MB absolute ceiling
// readBlockFromWAL reads a block's data from its WAL entry.
// walOff is relative to the start of the WAL region; lba selects which
// block to return out of a possibly multi-block WRITE entry.
func (v *BlockVol) readBlockFromWAL(walOff uint64, lba uint64) ([]byte, error) {
	// Read the WAL entry header to get the full entry size.
	headerBuf := make([]byte, walEntryHeaderSize)
	absOff := int64(v.super.WALOffset + walOff)
	if _, err := v.fd.ReadAt(headerBuf, absOff); err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: read header at %d: %w", absOff, err)
	}
	// Check entry type first — TRIM has no data payload, so Length is
	// metadata (trim extent), not a data size to allocate.
	entryType := headerBuf[16] // Type is at offset LSN(8) + Epoch(8) = 16
	if entryType == EntryTypeTrim {
		// TRIM entry: return zeros regardless of Length field.
		return make([]byte, v.super.BlockSize), nil
	}
	if entryType != EntryTypeWrite {
		return nil, fmt.Errorf("readBlockFromWAL: expected WRITE or TRIM entry, got type 0x%02x", entryType)
	}
	// Parse and validate the data Length field before allocating (WRITE only).
	// A corrupt header must not be able to make us allocate gigabytes.
	dataLen := v.parseDataLength(headerBuf)
	if uint64(dataLen) > v.super.WALSize || uint64(dataLen) > maxWALEntryDataLen {
		return nil, fmt.Errorf("readBlockFromWAL: corrupt entry length %d exceeds WAL size %d", dataLen, v.super.WALSize)
	}
	// Re-read header + payload as one contiguous buffer and decode it.
	entryLen := walEntryHeaderSize + int(dataLen)
	fullBuf := make([]byte, entryLen)
	if _, err := v.fd.ReadAt(fullBuf, absOff); err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: read entry at %d: %w", absOff, err)
	}
	entry, err := DecodeWALEntry(fullBuf)
	if err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: decode: %w", err)
	}
	// Find the block within the entry's data. The dirty map stores the
	// same entry offset for every block of a multi-block write, so the
	// block's position is derived from lba - entry.LBA.
	blockOffset := (lba - entry.LBA) * uint64(v.super.BlockSize)
	if blockOffset+uint64(v.super.BlockSize) > uint64(len(entry.Data)) {
		return nil, fmt.Errorf("readBlockFromWAL: block offset %d out of range for entry data len %d", blockOffset, len(entry.Data))
	}
	// Copy out so the caller never aliases the decoded entry's buffer.
	block := make([]byte, v.super.BlockSize)
	copy(block, entry.Data[blockOffset:blockOffset+uint64(v.super.BlockSize)])
	return block, nil
}
// parseDataLength extracts the Length field from a WAL entry header buffer.
func (v *BlockVol) parseDataLength(headerBuf []byte) uint32 {
	// Field layout: LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8) = 26.
	const lengthFieldOff = 26
	return binary.LittleEndian.Uint32(headerBuf[lengthFieldOff:])
}
// readBlockFromExtent reads a block directly from the extent region,
// which begins immediately after the WAL.
func (v *BlockVol) readBlockFromExtent(lba uint64) ([]byte, error) {
	dataStart := v.super.WALOffset + v.super.WALSize
	byteOffset := dataStart + lba*uint64(v.super.BlockSize)
	buf := make([]byte, v.super.BlockSize)
	_, err := v.fd.ReadAt(buf, int64(byteOffset))
	if err != nil {
		return nil, fmt.Errorf("readBlockFromExtent: pread at %d: %w", byteOffset, err)
	}
	return buf, nil
}
// Trim marks blocks as deallocated. Subsequent reads return zeros.
// The trim is recorded in the WAL with a Length field so the flusher
// can zero the extent region and recovery can replay the trim.
func (v *BlockVol) Trim(lba uint64, length uint32) error {
	if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil {
		return err
	}
	seq := v.nextLSN.Add(1) - 1
	walOff, err := v.wal.Append(&WALEntry{
		LSN:    seq,
		Epoch:  0,
		Type:   EntryTypeTrim,
		LBA:    lba,
		Length: length, // trim extent in bytes — metadata, not payload
	})
	if err != nil {
		return fmt.Errorf("blockvol: Trim: %w", err)
	}
	// Mark every trimmed block dirty so the flusher sees it; the read
	// path recognizes the TRIM entry type and serves zeros.
	blockCount := uint64(length / v.super.BlockSize)
	for i := uint64(0); i < blockCount; i++ {
		v.dirtyMap.Put(lba+i, walOff, seq, v.super.BlockSize)
	}
	return nil
}
// Path returns the file path of the block volume, as it was passed to
// CreateBlockVol or OpenBlockVol.
func (v *BlockVol) Path() string {
	return v.path
}
// Info returns a snapshot of the volume's read-only metadata.
func (v *BlockVol) Info() VolumeInfo {
	info := VolumeInfo{
		VolumeSize: v.super.VolumeSize,
		BlockSize:  v.super.BlockSize,
		ExtentSize: v.super.ExtentSize,
		WALSize:    v.super.WALSize,
	}
	info.Healthy = v.healthy.Load()
	return info
}
// VolumeInfo contains read-only volume metadata, as returned by Info.
type VolumeInfo struct {
	VolumeSize uint64 // logical volume size in bytes
	BlockSize  uint32 // block size in bytes
	ExtentSize uint32 // extent size in bytes
	WALSize    uint64 // WAL region size in bytes
	Healthy    bool   // false once the group committer has reported degradation
}
// SyncCache ensures all previously written WAL entries are durable on disk.
// It submits a sync request to the group committer, which batches fsyncs
// across concurrent callers (see GroupCommitter).
func (v *BlockVol) SyncCache() error {
	return v.groupCommit.Submit()
}
// Close shuts down the block volume and closes the file.
// Shutdown order: group committer → stop flusher goroutine → final flush → close fd.
func (v *BlockVol) Close() error {
	if v.groupCommit != nil {
		v.groupCommit.Stop()
	}
	var flushErr error
	if v.flusher != nil {
		// Stop the background goroutine before the final flush so the
		// two never run concurrently.
		v.flusher.Stop()
		flushErr = v.flusher.FlushOnce()
	}
	// The fd is always closed, but a flush failure takes precedence in
	// the returned error.
	closeErr := v.fd.Close()
	if flushErr == nil {
		return closeErr
	}
	return flushErr
}

3903
weed/storage/blockvol/blockvol_qa_test.go
File diff suppressed because it is too large
View File

685
weed/storage/blockvol/blockvol_test.go

@@ -0,0 +1,685 @@
package blockvol
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"testing"
)
// TestBlockVol drives every blockvol engine test as a named subtest.
func TestBlockVol(t *testing.T) {
	type subtest struct {
		name string
		fn   func(t *testing.T)
	}
	cases := []subtest{
		{"write_then_read", testWriteThenRead},
		{"overwrite_read_latest", testOverwriteReadLatest},
		{"write_no_sync_not_durable", testWriteNoSyncNotDurable},
		{"write_sync_durable", testWriteSyncDurable},
		{"write_multiple_sync", testWriteMultipleSync},
		{"read_unflushed", testReadUnflushed},
		{"read_flushed", testReadFlushed},
		{"read_mixed_dirty_clean", testReadMixedDirtyClean},
		{"wal_read_corrupt_length", testWALReadCorruptLength},
		{"open_invalid_superblock", testOpenInvalidSuperblock},
		{"trim_large_length_read_returns_zero", testTrimLargeLengthReadReturnsZero},
		// Task 1.10: Lifecycle tests.
		{"lifecycle_create_close_reopen", testLifecycleCreateCloseReopen},
		{"lifecycle_close_flushes_dirty", testLifecycleCloseFlushes},
		{"lifecycle_double_close", testLifecycleDoubleClose},
		{"lifecycle_info", testLifecycleInfo},
		{"lifecycle_write_sync_close_reopen", testLifecycleWriteSyncCloseReopen},
		// Task 1.11: Crash stress test.
		{"crash_stress_100", testCrashStress100},
	}
	for _, c := range cases {
		t.Run(c.name, c.fn)
	}
}
// createTestVol builds a small throwaway volume (1MB, 4KB blocks, 256KB
// WAL) in a per-test temp directory.
func createTestVol(t *testing.T) *BlockVol {
	t.Helper()
	opts := CreateOptions{
		VolumeSize: 1 << 20,   // 1MB
		BlockSize:  4096,
		WALSize:    256 << 10, // 256KB WAL
	}
	vol, err := CreateBlockVol(filepath.Join(t.TempDir(), "test.blockvol"), opts)
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	return vol
}
// makeBlock returns a 4KB block filled entirely with the given byte.
func makeBlock(fill byte) []byte {
	blk := make([]byte, 4096)
	for idx := range blk {
		blk[idx] = fill
	}
	return blk
}
// testWriteThenRead: the most basic round trip — one write, one read.
func testWriteThenRead(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	want := makeBlock('A')
	if err := v.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("read data does not match written data")
	}
}
// testOverwriteReadLatest: two writes to the same LBA; only the second
// must be visible.
func testOverwriteReadLatest(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(A): %v", err)
	}
	if err := v.WriteLBA(0, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(B): %v", err)
	}
	latest, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(latest, makeBlock('B')) {
		t.Error("read should return latest write ('B'), not 'A'")
	}
}
// testWriteNoSyncNotDurable: a write without SyncCache may or may not
// survive a crash; this only checks that reopening succeeds either way.
func testWriteNoSyncNotDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	// Simulate crash: close the fd without sync. Stop the background
	// goroutines first (as testCrashStress100 does) so they are not
	// leaked for the rest of the test binary's lifetime.
	v.groupCommit.Stop()
	v.flusher.Stop()
	v.fd.Close()
	// Reopen -- without recovery (Phase 1.9), data MAY be lost.
	// This test just verifies we can reopen without error.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol after crash: %v", err)
	}
	defer v2.Close()
	// Data may or may not be present -- both are correct without SyncCache.
}
// testWriteSyncDurable verifies that a synced write survives a reopen.
// Group commit (Task 1.7) and automatic replay wiring came later, so this
// test publishes the WAL position into the superblock by hand and replays
// the single WAL entry manually after reopening.
func testWriteSyncDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	// Sync WAL (manual fsync for now, group commit in Task 1.7).
	if err := v.wal.Sync(); err != nil {
		t.Fatalf("Sync: %v", err)
	}
	// Update superblock WALHead so reopen knows where entries are.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALCheckpointLSN = 0
	if _, err := v.fd.Seek(0, 0); err != nil {
		t.Fatalf("Seek: %v", err)
	}
	if _, err := v.super.WriteTo(v.fd); err != nil {
		t.Fatalf("WriteTo: %v", err)
	}
	v.fd.Sync()
	// NOTE(review): the volume's goroutines are left running after this
	// raw fd close; consider stopping them as testCrashStress100 does.
	v.fd.Close()
	// Reopen and manually replay WAL to verify data is durable.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	// Manually replay: read WAL entry and populate dirty map.
	// WALHead bytes from the start of the WAL region cover exactly the
	// one entry written above, so walOffset 0 is correct below.
	replayBuf := make([]byte, v2.super.WALHead)
	if _, err := v2.fd.ReadAt(replayBuf, int64(v2.super.WALOffset)); err != nil {
		t.Fatalf("read WAL for replay: %v", err)
	}
	entry, err := DecodeWALEntry(replayBuf)
	if err != nil {
		t.Fatalf("decode WAL entry: %v", err)
	}
	blocks := entry.Length / v2.super.BlockSize
	for i := uint32(0); i < blocks; i++ {
		v2.dirtyMap.Put(entry.LBA+uint64(i), 0, entry.LSN, v2.super.BlockSize)
	}
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not durable after sync + reopen")
	}
}
// testWriteMultipleSync: ten blocks with distinct fills all read back.
func testWriteMultipleSync(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	const n = uint64(10)
	for lba := uint64(0); lba < n; lba++ {
		if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	for lba := uint64(0); lba < n; lba++ {
		got, err := v.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if want := makeBlock(byte('A' + lba)); !bytes.Equal(got, want) {
			t.Errorf("block %d: data mismatch", lba)
		}
	}
}
// testReadUnflushed: a read before any flush is served from the dirty
// map / WAL path.
func testReadUnflushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	want := makeBlock('X')
	if err := v.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("unflushed read: data mismatch")
	}
}
// testReadFlushed: after a simulated flush (data copied to the extent
// region, dirty-map entry removed) a read takes the extent path.
func testReadFlushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	want := makeBlock('F')
	if err := v.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	extentBase := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(want, int64(extentBase)); err != nil {
		t.Fatalf("manual flush write: %v", err)
	}
	v.dirtyMap.Delete(0)
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after flush: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("flushed read: data mismatch")
	}
}
// testReadMixedDirtyClean: one flushed block, two dirty blocks, and
// three never-written blocks are all served correctly in one volume.
func testReadMixedDirtyClean(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	// Dirty blocks 0, 2, 4 with fills 'A', 'C', 'E'.
	for _, lba := range []uint64{0, 2, 4} {
		if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	// Manually flush block 0 into the extent region; 2 and 4 stay dirty.
	extentBase := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(makeBlock('A'), int64(extentBase)); err != nil {
		t.Fatalf("manual flush: %v", err)
	}
	v.dirtyMap.Delete(0)
	// Block 0 comes from the extent, 2 and 4 from the dirty map.
	state := map[uint64]string{0: "flushed", 2: "dirty", 4: "dirty"}
	for _, lba := range []uint64{0, 2, 4} {
		got, err := v.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+lba))) {
			t.Errorf("block %d (%s) mismatch", lba, state[lba])
		}
	}
	// Blocks 1, 3, 5 were never written and must read as zeros.
	zeros := make([]byte, 4096)
	for _, lba := range []uint64{1, 3, 5} {
		got, err := v.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, zeros) {
			t.Errorf("block %d (unwritten) should be zeros", lba)
		}
	}
}
// testWALReadCorruptLength corrupts the on-disk Length field of a WAL
// entry and verifies the read path rejects it instead of attempting a
// huge allocation (see maxWALEntryDataLen in readBlockFromWAL).
func testWALReadCorruptLength(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	// Write a valid block.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	// Get the WAL offset from dirty map.
	walOff, _, _, ok := v.dirtyMap.Get(0)
	if !ok {
		t.Fatal("block 0 not in dirty map")
	}
	// Corrupt the Length field in the WAL entry on disk.
	// Length is at header offset 26 (LSN=8 + Epoch=8 + Type=1 + Flags=1 + LBA=8),
	// matching parseDataLength in blockvol.go.
	absOff := int64(v.super.WALOffset + walOff)
	lengthOff := absOff + 26
	var hugeLenBuf [4]byte
	binary.LittleEndian.PutUint32(hugeLenBuf[:], 999999999) // ~1GB
	if _, err := v.fd.WriteAt(hugeLenBuf[:], lengthOff); err != nil {
		t.Fatalf("corrupt length: %v", err)
	}
	// ReadLBA should detect the corrupt length and error (not panic/OOM).
	_, err := v.ReadLBA(0, 4096)
	if err == nil {
		t.Error("expected error reading corrupt WAL entry, got nil")
	}
}
// testOpenInvalidSuperblock zeroes the superblock's BlockSize field on
// disk and verifies OpenBlockVol rejects the volume with
// ErrInvalidSuperblock rather than opening it.
func testOpenInvalidSuperblock(t *testing.T) {
	dir := t.TempDir()
	// Create a valid volume, then corrupt BlockSize to 0.
	path := filepath.Join(dir, "corrupt.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{VolumeSize: 1024 * 1024})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	v.Close()
	// Corrupt BlockSize (at superblock offset 36: Magic=4 + Version=2 + Flags=2 + UUID=16 + VolumeSize=8 + ExtentSize=4 = 36).
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("open for corrupt: %v", err)
	}
	var zeroBuf [4]byte
	if _, err := fd.WriteAt(zeroBuf[:], 36); err != nil {
		fd.Close()
		t.Fatalf("corrupt blocksize: %v", err)
	}
	fd.Close()
	// OpenBlockVol should reject the corrupt superblock.
	_, err = OpenBlockVol(path)
	if err == nil {
		t.Fatal("expected error opening volume with BlockSize=0")
	}
	if !errors.Is(err, ErrInvalidSuperblock) {
		t.Errorf("expected ErrInvalidSuperblock, got: %v", err)
	}
}
// --- Task 1.10: Lifecycle tests ---
// testLifecycleCreateCloseReopen: a synced write survives a clean
// create → write → sync → close → reopen cycle.
func testLifecycleCreateCloseReopen(t *testing.T) {
	path := filepath.Join(t.TempDir(), "lifecycle.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 << 20,
		BlockSize:  4096,
		WALSize:    256 << 10,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	want := makeBlock('L')
	if err := v.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}
	// Reopen and verify the data survived.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("data not durable after create→write→sync→close→reopen")
	}
}
// testLifecycleCloseFlushes: Close performs a final flush, so all dirty
// blocks must be readable after a reopen.
func testLifecycleCloseFlushes(t *testing.T) {
	path := filepath.Join(t.TempDir(), "flush.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 << 20,
		BlockSize:  4096,
		WALSize:    256 << 10,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	const n = uint64(5)
	for lba := uint64(0); lba < n; lba++ {
		if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Close does a final flush — dirty map should be drained.
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	for lba := uint64(0); lba < n; lba++ {
		got, err := v2.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+lba))) {
			t.Errorf("block %d: data mismatch after close+reopen", lba)
		}
	}
}
// testLifecycleDoubleClose: calling Close twice must not panic.
func testLifecycleDoubleClose(t *testing.T) {
	v := createTestVol(t)
	if err := v.Close(); err != nil {
		t.Fatalf("first Close: %v", err)
	}
	// The group committer and flusher Stop calls are idempotent; the
	// second fd.Close() returns an error, which is deliberately ignored.
	_ = v.Close()
}
// testLifecycleInfo: Info reflects the creation options and reports a
// healthy volume.
func testLifecycleInfo(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	info := v.Info()
	const wantSize = 1 * 1024 * 1024
	if info.VolumeSize != wantSize {
		t.Errorf("VolumeSize = %d, want %d", info.VolumeSize, wantSize)
	}
	if info.BlockSize != 4096 {
		t.Errorf("BlockSize = %d, want 4096", info.BlockSize)
	}
	if !info.Healthy {
		t.Error("Healthy = false, want true")
	}
}
// testLifecycleWriteSyncCloseReopen runs two full write/sync/close
// cycles (including an overwrite) and verifies the final state.
// Unlike the first draft, every Close and ReadLBA error is checked so a
// failure surfaces at the call that caused it rather than as a
// confusing data-mismatch report.
func testLifecycleWriteSyncCloseReopen(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "wsco.blockvol")
	// Cycle: create → write → sync → close → reopen → write → sync → close → verify.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	if err := v.WriteLBA(0, makeBlock('X')); err != nil {
		t.Fatalf("WriteLBA round 1: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache round 1: %v", err)
	}
	if err := v.Close(); err != nil {
		t.Fatalf("Close round 1: %v", err)
	}
	// Reopen, write more.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol round 2: %v", err)
	}
	if err := v2.WriteLBA(1, makeBlock('Y')); err != nil {
		t.Fatalf("WriteLBA round 2: %v", err)
	}
	// Overwrite block 0.
	if err := v2.WriteLBA(0, makeBlock('Z')); err != nil {
		t.Fatalf("WriteLBA overwrite: %v", err)
	}
	if err := v2.SyncCache(); err != nil {
		t.Fatalf("SyncCache round 2: %v", err)
	}
	if err := v2.Close(); err != nil {
		t.Fatalf("Close round 2: %v", err)
	}
	// Final reopen — verify.
	v3, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol round 3: %v", err)
	}
	defer v3.Close()
	got0, err := v3.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0): %v", err)
	}
	if !bytes.Equal(got0, makeBlock('Z')) {
		t.Error("block 0: expected 'Z' (overwritten)")
	}
	got1, err := v3.ReadLBA(1, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(1): %v", err)
	}
	if !bytes.Equal(got1, makeBlock('Y')) {
		t.Error("block 1: expected 'Y'")
	}
}
// --- Task 1.11: Crash stress test ---
// testCrashStress100 runs 100 crash/recover cycles against an in-memory
// oracle. Each iteration applies a deterministic batch of write/trim
// ops, publishes the WAL position to the superblock, hard-crashes the
// volume (stop goroutines, raw fd close), reopens with recovery, and
// checks every oracle block against what the volume reads back.
func testCrashStress100(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "stress.blockvol")
	const (
		volumeSize = 256 * 1024 // 256KB volume (64 blocks of 4KB)
		blockSize  = 4096
		walSize    = 64 * 1024 // 64KB WAL
		maxLBA     = volumeSize / blockSize
		iterations = 100
	)
	// Oracle: tracks the expected state of each block.
	oracle := make(map[uint64]byte) // lba → fill byte (0 = zeros/trimmed)
	// Create initial volume.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: volumeSize,
		BlockSize:  blockSize,
		WALSize:    walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	for iter := 0; iter < iterations; iter++ {
		// Deterministic "random" ops using iteration number, so the
		// run is reproducible without seeding.
		numOps := 3 + (iter % 5) // 3-7 ops per iteration
		for op := 0; op < numOps; op++ {
			lba := uint64((iter*7 + op*13) % maxLBA)
			action := (iter + op) % 3 // 0=write, 1=overwrite, 2=trim
			switch action {
			case 0, 1: // write
				fill := byte('A' + (iter+op)%26)
				err := v.WriteLBA(lba, makeBlock(fill))
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						continue // WAL full, skip this op
					}
					t.Fatalf("iter %d op %d: WriteLBA(%d): %v", iter, op, lba, err)
				}
				oracle[lba] = fill
			case 2: // trim
				err := v.Trim(lba, blockSize)
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						continue
					}
					t.Fatalf("iter %d op %d: Trim(%d): %v", iter, op, lba, err)
				}
				oracle[lba] = 0
			}
		}
		// Sync WAL.
		if err := v.SyncCache(); err != nil {
			t.Fatalf("iter %d: SyncCache: %v", iter, err)
		}
		// Simulate crash: close fd without clean shutdown.
		v.fd.Sync() // ensure WAL is on disk
		// Write superblock with current WAL positions for recovery.
		v.super.WALHead = v.wal.LogicalHead()
		v.super.WALTail = v.wal.LogicalTail()
		if _, seekErr := v.fd.Seek(0, 0); seekErr != nil {
			t.Fatalf("iter %d: Seek: %v", iter, seekErr)
		}
		v.super.WriteTo(v.fd)
		v.fd.Sync()
		// Hard crash: close fd, stop goroutines. No final flush runs,
		// so recovery must rebuild the dirty map from the WAL alone.
		v.groupCommit.Stop()
		v.flusher.Stop()
		v.fd.Close()
		// Reopen with recovery.
		v, err = OpenBlockVol(path)
		if err != nil {
			t.Fatalf("iter %d: OpenBlockVol: %v", iter, err)
		}
		// Verify oracle against actual reads.
		for lba, fill := range oracle {
			got, readErr := v.ReadLBA(lba, blockSize)
			if readErr != nil {
				t.Fatalf("iter %d: ReadLBA(%d): %v", iter, lba, readErr)
			}
			var expected []byte
			if fill == 0 {
				expected = make([]byte, blockSize)
			} else {
				expected = makeBlock(fill)
			}
			if !bytes.Equal(got, expected) {
				t.Fatalf("iter %d: block %d mismatch: got[0]=%d want[0]=%d", iter, lba, got[0], expected[0])
			}
		}
	}
	v.Close()
}
// testTrimLargeLengthReadReturnsZero: after a trim, reads of the trimmed
// block return zeros even though the data still sits in the WAL. The
// TRIM entry's Length field is trim-extent metadata, not a payload size,
// so the read path must not allocate from it (see readBlockFromWAL).
func testTrimLargeLengthReadReturnsZero(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()
	// Write data first.
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	// Verify data is written.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA before trim: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Fatal("data mismatch before trim")
	}
	// Trim one block. (Despite the test name, the length here is a
	// single 4096-byte block — the point is that the TRIM Length is
	// metadata and never treated as a data allocation size.)
	if err := v.Trim(0, 4096); err != nil {
		t.Fatalf("Trim: %v", err)
	}
	// Read should return zeros (TRIM entry in dirty map).
	got, err = v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after trim: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("read after trim should return zeros")
	}
}

110
weed/storage/blockvol/dirty_map.go

@@ -0,0 +1,110 @@
package blockvol
import "sync"
// dirtyEntry tracks a single LBA's location in the WAL.
type dirtyEntry struct {
walOffset uint64
lsn uint64
length uint32
}
// DirtyMap is a concurrency-safe map from LBA to WAL offset.
// Phase 1 uses a single shard; Phase 3 upgrades to 256 shards.
type DirtyMap struct {
mu sync.RWMutex
m map[uint64]dirtyEntry
}
// NewDirtyMap creates an empty dirty map.
func NewDirtyMap() *DirtyMap {
return &DirtyMap{m: make(map[uint64]dirtyEntry)}
}
// Put records that the given LBA has dirty data at walOffset.
func (d *DirtyMap) Put(lba uint64, walOffset uint64, lsn uint64, length uint32) {
d.mu.Lock()
d.m[lba] = dirtyEntry{walOffset: walOffset, lsn: lsn, length: length}
d.mu.Unlock()
}
// Get returns the dirty entry for lba. ok is false if lba is not dirty.
func (d *DirtyMap) Get(lba uint64) (walOffset uint64, lsn uint64, length uint32, ok bool) {
d.mu.RLock()
e, found := d.m[lba]
d.mu.RUnlock()
if !found {
return 0, 0, 0, false
}
return e.walOffset, e.lsn, e.length, true
}
// Delete removes a single LBA from the dirty map.
func (d *DirtyMap) Delete(lba uint64) {
d.mu.Lock()
delete(d.m, lba)
d.mu.Unlock()
}
// rangeEntry is a snapshot of one dirty entry for lock-free iteration.
type rangeEntry struct {
lba uint64
walOffset uint64
lsn uint64
length uint32
}
// Range calls fn for each dirty entry with LBA in [start, start+count).
// Entries are copied under lock, then fn is called without holding the lock,
// so fn may safely call back into DirtyMap (Put, Delete, etc.).
func (d *DirtyMap) Range(start uint64, count uint32, fn func(lba, walOffset, lsn uint64, length uint32)) {
end := start + uint64(count)
// Snapshot matching entries under lock.
d.mu.RLock()
entries := make([]rangeEntry, 0, len(d.m))
for lba, e := range d.m {
if lba >= start && lba < end {
entries = append(entries, rangeEntry{lba, e.walOffset, e.lsn, e.length})
}
}
d.mu.RUnlock()
// Call fn without lock.
for _, e := range entries {
fn(e.lba, e.walOffset, e.lsn, e.length)
}
}
// Len returns the number of dirty entries.
func (d *DirtyMap) Len() int {
d.mu.RLock()
n := len(d.m)
d.mu.RUnlock()
return n
}
// SnapshotEntry is an exported snapshot of one dirty entry.
type SnapshotEntry struct {
Lba uint64 // logical block address of the entry
WalOffset uint64 // WAL-relative offset of the entry's data
Lsn uint64 // log sequence number of the write
Length uint32 // payload length in bytes
}
// Snapshot returns a copy of all dirty entries. The copy is built under the
// read lock; callers receive an independent slice they may iterate freely.
func (d *DirtyMap) Snapshot() []SnapshotEntry {
	d.mu.RLock()
	defer d.mu.RUnlock()
	out := make([]SnapshotEntry, 0, len(d.m))
	for lba, entry := range d.m {
		out = append(out, SnapshotEntry{
			Lba:       lba,
			WalOffset: entry.walOffset,
			Lsn:       entry.lsn,
			Length:    entry.length,
		})
	}
	return out
}

153
weed/storage/blockvol/dirty_map_test.go

@ -0,0 +1,153 @@
package blockvol
import (
"sync"
"testing"
)
// TestDirtyMap runs each DirtyMap subtest under its own name.
func TestDirtyMap(t *testing.T) {
	for _, tc := range []struct {
		name string
		fn   func(*testing.T)
	}{
		{"dirty_put_get", testDirtyPutGet},
		{"dirty_overwrite", testDirtyOverwrite},
		{"dirty_delete", testDirtyDelete},
		{"dirty_range_query", testDirtyRangeQuery},
		{"dirty_empty", testDirtyEmpty},
		{"dirty_concurrent_rw", testDirtyConcurrentRW},
		{"dirty_range_modify", testDirtyRangeModify},
	} {
		t.Run(tc.name, tc.fn)
	}
}
// testDirtyPutGet verifies a Put is visible through Get with all fields intact.
func testDirtyPutGet(t *testing.T) {
	m := NewDirtyMap()
	m.Put(100, 5000, 1, 4096)
	gotOff, gotLSN, gotLen, ok := m.Get(100)
	if !ok {
		t.Fatal("Get(100) returned not-found")
	}
	if gotOff != 5000 {
		t.Errorf("walOffset = %d, want 5000", gotOff)
	}
	if gotLSN != 1 {
		t.Errorf("lsn = %d, want 1", gotLSN)
	}
	if gotLen != 4096 {
		t.Errorf("length = %d, want 4096", gotLen)
	}
}

// testDirtyOverwrite verifies a second Put for the same LBA replaces the first.
func testDirtyOverwrite(t *testing.T) {
	m := NewDirtyMap()
	m.Put(100, 5000, 1, 4096)
	m.Put(100, 9000, 2, 4096)
	gotOff, gotLSN, _, ok := m.Get(100)
	if !ok {
		t.Fatal("Get(100) returned not-found")
	}
	if gotOff != 9000 {
		t.Errorf("walOffset = %d, want 9000 (latest)", gotOff)
	}
	if gotLSN != 2 {
		t.Errorf("lsn = %d, want 2 (latest)", gotLSN)
	}
}

// testDirtyDelete verifies Get reports not-found once the LBA is deleted.
func testDirtyDelete(t *testing.T) {
	m := NewDirtyMap()
	m.Put(100, 5000, 1, 4096)
	m.Delete(100)
	if _, _, _, ok := m.Get(100); ok {
		t.Error("Get(100) should return not-found after delete")
	}
}
// testDirtyRangeQuery checks that Range visits exactly the LBAs in
// [start, start+count).
func testDirtyRangeQuery(t *testing.T) {
	m := NewDirtyMap()
	for lba := uint64(100); lba < 110; lba++ {
		m.Put(lba, lba*1000, lba, 4096)
	}
	seen := make(map[uint64]bool)
	m.Range(100, 10, func(lba, _, _ uint64, _ uint32) {
		seen[lba] = true
	})
	if len(seen) != 10 {
		t.Errorf("Range returned %d entries, want 10", len(seen))
	}
	for lba := uint64(100); lba < 110; lba++ {
		if !seen[lba] {
			t.Errorf("Range missing LBA %d", lba)
		}
	}
}

// testDirtyEmpty checks lookups and Len on a map with no entries.
func testDirtyEmpty(t *testing.T) {
	m := NewDirtyMap()
	for _, lba := range []uint64{0, 999} {
		if _, _, _, ok := m.Get(lba); ok {
			t.Errorf("empty map: Get(%d) should return not-found", lba)
		}
	}
	if n := m.Len(); n != 0 {
		t.Errorf("empty map: Len() = %d, want 0", n)
	}
}
// testDirtyConcurrentRW hammers the map from many goroutines; it passes as
// long as the race detector stays quiet.
func testDirtyConcurrentRW(t *testing.T) {
	m := NewDirtyMap()
	const workers = 16
	const opsPerWorker = 1000
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func(id int) {
			defer wg.Done()
			base := uint64(id * opsPerWorker)
			for i := uint64(0); i < opsPerWorker; i++ {
				lba := base + i
				m.Put(lba, lba*10, lba, 4096)
				m.Get(lba)
				if i%3 == 0 {
					m.Delete(lba)
				}
			}
		}(w)
	}
	wg.Wait()
	// No assertion on final state -- the test passes if no race detector fires.
}

// testDirtyRangeModify deletes entries from inside the Range callback,
// proving the callback runs without the map lock held (no deadlock).
func testDirtyRangeModify(t *testing.T) {
	m := NewDirtyMap()
	for lba := uint64(0); lba < 10; lba++ {
		m.Put(lba, lba*1000, lba+1, 4096)
	}
	m.Range(0, 10, func(lba, _, _ uint64, _ uint32) {
		m.Delete(lba)
	})
	if m.Len() != 0 {
		t.Errorf("after Range+Delete: Len() = %d, want 0", m.Len())
	}
}

239
weed/storage/blockvol/flusher.go

@ -0,0 +1,239 @@
package blockvol
import (
"encoding/binary"
"fmt"
"os"
"sync"
"time"
)
// Flusher copies WAL entries to the extent region and frees WAL space.
// It runs as a background goroutine and can also be triggered manually.
type Flusher struct {
fd *os.File // backing volume file (shared with the WAL writer)
super *Superblock // in-memory superblock, rewritten on every checkpoint
wal *WALWriter // WAL whose tail is advanced after each flush
dirtyMap *DirtyMap // LBA -> WAL-offset index of unflushed writes
walOffset uint64 // absolute file offset of WAL region
walSize uint64 // WAL region size in bytes
blockSize uint32 // volume block size in bytes
extentStart uint64 // absolute file offset of extent region
mu sync.Mutex // guards checkpointLSN and checkpointTail
checkpointLSN uint64 // last flushed LSN
checkpointTail uint64 // WAL physical tail after last flush
interval time.Duration // periodic flush interval
notifyCh chan struct{} // wake-up for an immediate flush (capacity 1, drops extras)
stopCh chan struct{} // closed by Stop to end Run
done chan struct{} // closed by Run on exit
stopOnce sync.Once // makes Stop idempotent
}
// FlusherConfig configures the flusher. FD, Super, WAL, and DirtyMap are
// required; a zero Interval is replaced with the 100ms default by NewFlusher.
type FlusherConfig struct {
FD *os.File // backing volume file
Super *Superblock // superblock describing WAL/extent layout
WAL *WALWriter // WAL writer shared with the write path
DirtyMap *DirtyMap // dirty-LBA index shared with the write path
Interval time.Duration // default 100ms
}
// NewFlusher creates a flusher from cfg, deriving the WAL and extent layout
// from the superblock. Call Run() in a goroutine to start it.
func NewFlusher(cfg FlusherConfig) *Flusher {
	interval := cfg.Interval
	if interval == 0 {
		interval = 100 * time.Millisecond
	}
	sb := cfg.Super
	return &Flusher{
		fd:            cfg.FD,
		super:         sb,
		wal:           cfg.WAL,
		dirtyMap:      cfg.DirtyMap,
		walOffset:     sb.WALOffset,
		walSize:       sb.WALSize,
		blockSize:     sb.BlockSize,
		extentStart:   sb.WALOffset + sb.WALSize, // extents begin right after the WAL
		checkpointLSN: sb.WALCheckpointLSN,
		interval:      interval,
		notifyCh:      make(chan struct{}, 1),
		stopCh:        make(chan struct{}),
		done:          make(chan struct{}),
	}
}
// Run is the flusher main loop: it flushes on every interval tick and on
// every Notify, until Stop is called. Call in a goroutine.
func (f *Flusher) Run() {
	defer close(f.done)
	tick := time.NewTicker(f.interval)
	defer tick.Stop()
	for {
		select {
		case <-f.stopCh:
			return
		case <-tick.C:
			f.FlushOnce()
		case <-f.notifyCh:
			f.FlushOnce()
		}
	}
}
// Notify wakes the flusher for an immediate flush cycle. The signal is
// dropped if a wake-up is already pending, so Notify never blocks.
func (f *Flusher) Notify() {
	select {
	case f.notifyCh <- struct{}{}:
	default:
		// A notification is already queued; one flush will cover both.
	}
}
// Stop shuts down the flusher and waits for Run to exit.
// Safe to call multiple times.
func (f *Flusher) Stop() {
	f.stopOnce.Do(func() { close(f.stopCh) })
	<-f.done
}
// FlushOnce performs a single flush cycle: snapshot the dirty map, copy each
// entry's data from the WAL to the extent region, fsync, drop the flushed
// entries from the dirty map, advance the WAL tail, and persist the new
// checkpoint in the superblock. Returns nil immediately when nothing is dirty.
//
// Fixes over the previous version: removed an unused local flushEntry type,
// and the error from updateSuperblockCheckpoint is now propagated instead of
// being silently dropped.
func (f *Flusher) FlushOnce() error {
	// Snapshot the dirty map so iteration does not hold its lock.
	entries := f.dirtyMap.Snapshot()
	if len(entries) == 0 {
		return nil
	}
	// Track the highest flushed LSN (new checkpoint) and the furthest WAL
	// byte consumed (new tail) across all entries.
	var maxLSN uint64
	var maxWALEnd uint64
	for _, e := range entries {
		// Read the WAL entry header to learn its type and payload length.
		headerBuf := make([]byte, walEntryHeaderSize)
		absWALOff := int64(f.walOffset + e.WalOffset)
		if _, err := f.fd.ReadAt(headerBuf, absWALOff); err != nil {
			return fmt.Errorf("flusher: read WAL header at %d: %w", absWALOff, err)
		}
		entryType := headerBuf[16] // Type at LSN(8)+Epoch(8)=16
		dataLen := parseLength(headerBuf)
		if entryType == EntryTypeWrite && dataLen > 0 {
			// Read the full entry (header + payload) and decode it.
			entryLen := walEntryHeaderSize + int(dataLen)
			fullBuf := make([]byte, entryLen)
			if _, err := f.fd.ReadAt(fullBuf, absWALOff); err != nil {
				return fmt.Errorf("flusher: read WAL entry at %d: %w", absWALOff, err)
			}
			entry, err := DecodeWALEntry(fullBuf)
			if err != nil {
				return fmt.Errorf("flusher: decode WAL entry: %w", err)
			}
			// Copy the payload block-by-block into the extent region.
			blocks := entry.Length / f.blockSize
			for i := uint32(0); i < blocks; i++ {
				blockLBA := entry.LBA + uint64(i)
				extentOff := int64(f.extentStart + blockLBA*uint64(f.blockSize))
				blockData := entry.Data[i*f.blockSize : (i+1)*f.blockSize]
				if _, err := f.fd.WriteAt(blockData, extentOff); err != nil {
					return fmt.Errorf("flusher: write extent at LBA %d: %w", blockLBA, err)
				}
			}
			// Track WAL end position for tail advance.
			walEnd := e.WalOffset + uint64(entryLen)
			if walEnd > maxWALEnd {
				maxWALEnd = walEnd
			}
		} else if entryType == EntryTypeTrim {
			// TRIM entries: zero the extent region for this LBA.
			// Each dirty map entry represents one trimmed block.
			zeroBlock := make([]byte, f.blockSize)
			extentOff := int64(f.extentStart + e.Lba*uint64(f.blockSize))
			if _, err := f.fd.WriteAt(zeroBlock, extentOff); err != nil {
				return fmt.Errorf("flusher: zero extent at LBA %d: %w", e.Lba, err)
			}
			// TRIM entry has no data payload, just a header.
			walEnd := e.WalOffset + uint64(walEntryHeaderSize)
			if walEnd > maxWALEnd {
				maxWALEnd = walEnd
			}
		}
		if e.Lsn > maxLSN {
			maxLSN = e.Lsn
		}
	}
	// Make the extent writes durable before dropping WAL protection.
	if err := f.fd.Sync(); err != nil {
		return fmt.Errorf("flusher: fsync extent: %w", err)
	}
	// Remove flushed entries from dirty map.
	f.mu.Lock()
	for _, e := range entries {
		// Only remove if the dirty map entry still has the same LSN
		// (a newer write may have updated it).
		_, currentLSN, _, ok := f.dirtyMap.Get(e.Lba)
		if ok && currentLSN == e.Lsn {
			f.dirtyMap.Delete(e.Lba)
		}
	}
	f.checkpointLSN = maxLSN
	f.checkpointTail = maxWALEnd
	f.mu.Unlock()
	// Advance WAL tail to free space.
	if maxWALEnd > 0 {
		f.wal.AdvanceTail(maxWALEnd)
	}
	// Persist the checkpoint; this error was previously ignored.
	return f.updateSuperblockCheckpoint(maxLSN, f.wal.Tail())
}
// updateSuperblockCheckpoint writes the updated checkpoint to disk: it
// refreshes the in-memory superblock's checkpoint LSN and WAL head/tail,
// rewrites the superblock at file offset 0, and fsyncs.
//
// NOTE(review): the walTail parameter is currently unused -- the persisted
// tail is re-read via f.wal.LogicalTail() instead. Confirm whether the
// caller's f.wal.Tail() value was meant to be written here.
func (f *Flusher) updateSuperblockCheckpoint(checkpointLSN uint64, walTail uint64) error {
f.super.WALCheckpointLSN = checkpointLSN
f.super.WALHead = f.wal.LogicalHead()
f.super.WALTail = f.wal.LogicalTail()
// The superblock lives at the very start of the file.
if _, err := f.fd.Seek(0, 0); err != nil {
return fmt.Errorf("flusher: seek to superblock: %w", err)
}
if _, err := f.super.WriteTo(f.fd); err != nil {
return fmt.Errorf("flusher: write superblock: %w", err)
}
return f.fd.Sync()
}
// CheckpointLSN returns the last flushed LSN.
func (f *Flusher) CheckpointLSN() uint64 {
	f.mu.Lock()
	lsn := f.checkpointLSN
	f.mu.Unlock()
	return lsn
}
// parseLength extracts the Length field from a WAL entry header buffer.
// The field sits at byte offset 26: LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8).
func parseLength(headerBuf []byte) uint32 {
	const lengthOffset = 8 + 8 + 1 + 1 + 8 // = 26
	return binary.LittleEndian.Uint32(headerBuf[lengthOffset : lengthOffset+4])
}

261
weed/storage/blockvol/flusher_test.go

@ -0,0 +1,261 @@
package blockvol
import (
"bytes"
"path/filepath"
"testing"
"time"
)
// TestFlusher runs each flusher subtest under its own name.
func TestFlusher(t *testing.T) {
	for _, tc := range []struct {
		name string
		fn   func(*testing.T)
	}{
		{"flush_moves_data", testFlushMovesData},
		{"flush_idempotent", testFlushIdempotent},
		{"flush_concurrent_writes", testFlushConcurrentWrites},
		{"flush_frees_wal_space", testFlushFreesWALSpace},
		{"flush_partial", testFlushPartial},
	} {
		t.Run(tc.name, tc.fn)
	}
}

// createTestVolWithFlusher builds a small 1MB volume in a temp dir plus a
// flusher whose periodic interval is effectively disabled, so each test
// drives every flush explicitly through FlushOnce.
func createTestVolWithFlusher(t *testing.T) (*BlockVol, *Flusher) {
	t.Helper()
	path := filepath.Join(t.TempDir(), "test.blockvol")
	vol, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024, // 1MB
		BlockSize:  4096,
		WALSize:    256 * 1024, // 256KB WAL
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	fl := NewFlusher(FlusherConfig{
		FD:       vol.fd,
		Super:    &vol.super,
		WAL:      vol.wal,
		DirtyMap: vol.dirtyMap,
		Interval: 1 * time.Hour, // don't auto-flush in tests
	})
	return vol, fl
}
// testFlushMovesData: after one flush cycle the dirty map is empty, the
// checkpoint has advanced, and every block is served from the extent region.
func testFlushMovesData(t *testing.T) {
	vol, fl := createTestVolWithFlusher(t)
	defer vol.Close()
	for lba := uint64(0); lba < 10; lba++ {
		if err := vol.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	if n := vol.dirtyMap.Len(); n != 10 {
		t.Fatalf("dirty map len = %d, want 10", n)
	}
	if err := fl.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}
	if n := vol.dirtyMap.Len(); n != 0 {
		t.Errorf("after flush: dirty map len = %d, want 0", n)
	}
	if fl.CheckpointLSN() == 0 {
		t.Error("checkpoint LSN should be > 0 after flush")
	}
	// With an empty dirty map, reads must be served from the extent region.
	for lba := uint64(0); lba < 10; lba++ {
		got, err := vol.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after flush: %v", lba, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+lba))) {
			t.Errorf("block %d: data mismatch after flush", lba)
		}
	}
}

// testFlushIdempotent: flushing twice in a row is harmless.
func testFlushIdempotent(t *testing.T) {
	vol, fl := createTestVolWithFlusher(t)
	defer vol.Close()
	want := makeBlock('X')
	if err := vol.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	for _, label := range []string{"FlushOnce 1", "FlushOnce 2"} {
		if err := fl.FlushOnce(); err != nil {
			t.Fatalf("%s: %v", label, err)
		}
	}
	got, err := vol.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after double flush: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("data mismatch after double flush")
	}
}
// testFlushConcurrentWrites: blocks flushed to the extent region and blocks
// written afterwards (still in the WAL) are both readable, and an overwrite
// of a flushed block is served through the WAL again.
func testFlushConcurrentWrites(t *testing.T) {
	vol, fl := createTestVolWithFlusher(t)
	defer vol.Close()
	for lba := uint64(0); lba < 5; lba++ {
		if err := vol.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	// Move blocks 0-4 into the extent region.
	if err := fl.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}
	// Blocks 5-9 land in the WAL after the flush.
	for lba := uint64(5); lba < 10; lba++ {
		if err := vol.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	for lba := uint64(0); lba < 10; lba++ {
		got, err := vol.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+lba))) {
			t.Errorf("block %d: data mismatch", lba)
		}
	}
	if n := vol.dirtyMap.Len(); n != 5 {
		t.Errorf("dirty map len = %d, want 5", n)
	}
	// Overwriting a flushed block routes the read back through the WAL.
	want := makeBlock('Z')
	if err := vol.WriteLBA(0, want); err != nil {
		t.Fatalf("WriteLBA(0) overwrite: %v", err)
	}
	got, err := vol.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0) after overwrite: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Error("block 0: should return overwritten data 'Z'")
	}
}

// testFlushFreesWALSpace: after filling most of the WAL, a flush frees space
// so that subsequent writes succeed.
func testFlushFreesWALSpace(t *testing.T) {
	vol, fl := createTestVolWithFlusher(t)
	defer vol.Close()
	perEntry := uint64(walEntryHeaderSize + 4096)
	capacity := vol.super.WALSize / perEntry
	writeCount := int(capacity * 80 / 100) // ~80% of WAL capacity
	for i := 0; i < writeCount; i++ {
		if err := vol.WriteLBA(uint64(i), makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	// Keep writing until the WAL reports full (it may not, depending on layout).
	walFilled := false
	for i := writeCount; i < writeCount+int(capacity); i++ {
		if err := vol.WriteLBA(uint64(i%writeCount), makeBlock('X')); err != nil {
			walFilled = true
			break
		}
	}
	if err := fl.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}
	// Tail advanced -- a fresh write must now succeed.
	if err := vol.WriteLBA(0, makeBlock('Y')); err != nil {
		t.Fatalf("WriteLBA after flush: %v", err)
	}
	if walFilled {
		t.Log("WAL was full before flush, writes succeeded after flush")
	}
}
// testFlushPartial: a second flush after more writes advances the checkpoint
// past the first, and all data is readable from the extent region.
func testFlushPartial(t *testing.T) {
	vol, fl := createTestVolWithFlusher(t)
	defer vol.Close()
	for lba := uint64(0); lba < 5; lba++ {
		if err := vol.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	if err := fl.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}
	firstCheckpoint := fl.CheckpointLSN()
	for lba := uint64(5); lba < 10; lba++ {
		if err := vol.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}
	// A second flush must pick up only the new entries and advance the LSN.
	if err := fl.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce 2: %v", err)
	}
	secondCheckpoint := fl.CheckpointLSN()
	if secondCheckpoint <= firstCheckpoint {
		t.Errorf("checkpoint should advance: first=%d, second=%d", firstCheckpoint, secondCheckpoint)
	}
	for lba := uint64(0); lba < 10; lba++ {
		got, err := vol.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after two flushes: %v", lba, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+lba))) {
			t.Errorf("block %d: data mismatch after two flushes", lba)
		}
	}
}

167
weed/storage/blockvol/group_commit.go

@ -0,0 +1,167 @@
package blockvol
import (
"errors"
"sync"
"sync/atomic"
"time"
)
// ErrGroupCommitShutdown is returned to waiters whose sync request was
// abandoned because the committer stopped before their batch was fsynced.
var ErrGroupCommitShutdown = errors.New("blockvol: group committer shut down")
// GroupCommitter batches SyncCache requests and performs a single fsync
// for the entire batch. This amortizes the cost of fsync across many callers.
type GroupCommitter struct {
syncFunc func() error // called to fsync (injectable for testing)
maxDelay time.Duration // max wait before flushing a partial batch
maxBatch int // flush immediately when this many waiters accumulate
onDegraded func() // called when fsync fails
mu sync.Mutex // guards pending and stopped
pending []chan error // one buffered result channel per blocked Submit
stopped bool // set under mu by Run() before draining
notifyCh chan struct{} // wakes Run when a waiter enqueues (capacity 1)
stopCh chan struct{} // closed by Stop to request shutdown
done chan struct{} // closed by Run on exit
stopOnce sync.Once // makes Stop idempotent
syncCount atomic.Uint64 // number of fsyncs performed (for testing)
}
// GroupCommitterConfig configures the group committer. Zero values for
// MaxDelay, MaxBatch, and OnDegraded are replaced with defaults by
// NewGroupCommitter; SyncFunc has no default and must be provided.
type GroupCommitterConfig struct {
SyncFunc func() error // required: the fsync function
MaxDelay time.Duration // default 1ms
MaxBatch int // default 64
OnDegraded func() // optional: called on fsync error
}
// NewGroupCommitter creates a new group committer with defaults applied
// (MaxDelay 1ms, MaxBatch 64, no-op OnDegraded). Call Run() to start it.
func NewGroupCommitter(cfg GroupCommitterConfig) *GroupCommitter {
	gc := &GroupCommitter{
		syncFunc:   cfg.SyncFunc,
		maxDelay:   cfg.MaxDelay,
		maxBatch:   cfg.MaxBatch,
		onDegraded: cfg.OnDegraded,
		notifyCh:   make(chan struct{}, 1),
		stopCh:     make(chan struct{}),
		done:       make(chan struct{}),
	}
	if gc.maxDelay == 0 {
		gc.maxDelay = 1 * time.Millisecond
	}
	if gc.maxBatch == 0 {
		gc.maxBatch = 64
	}
	if gc.onDegraded == nil {
		gc.onDegraded = func() {}
	}
	return gc
}
// Run is the main loop. Call this in a goroutine.
//
// Per-batch lifecycle:
//  1. Block until the first Submit() arrives (or shutdown).
//  2. Keep collecting waiters until maxBatch is reached or maxDelay elapses.
//  3. Swap out the pending list, fsync once, and deliver the result to every
//     waiter in the batch.
func (gc *GroupCommitter) Run() {
defer close(gc.done)
for {
// Wait for first waiter or shutdown.
select {
case <-gc.stopCh:
gc.markStoppedAndDrain()
return
case <-gc.notifyCh:
}
// Collect batch: wait up to maxDelay for more waiters, or until maxBatch reached.
deadline := time.NewTimer(gc.maxDelay)
for {
gc.mu.Lock()
n := len(gc.pending)
gc.mu.Unlock()
if n >= gc.maxBatch {
// Batch is full: flush immediately rather than waiting out the timer.
deadline.Stop()
break
}
select {
case <-gc.stopCh:
deadline.Stop()
gc.markStoppedAndDrain()
return
case <-deadline.C:
// Delay budget spent; flush whatever has accumulated.
goto flush
case <-gc.notifyCh:
// Another waiter arrived; loop back and re-check the batch size.
continue
}
}
flush:
// Take all pending waiters.
gc.mu.Lock()
batch := gc.pending
gc.pending = nil
gc.mu.Unlock()
// A stray notify can wake us with nothing pending; just loop again.
if len(batch) == 0 {
continue
}
// Perform fsync. A single sync covers the entire batch.
err := gc.syncFunc()
gc.syncCount.Add(1)
if err != nil {
gc.onDegraded()
}
// Wake all waiters. Each channel is buffered (cap 1), so sends never block.
for _, ch := range batch {
ch <- err
}
}
}
// Submit submits a sync request and blocks until the batch fsync completes.
// Returns nil on success, the fsync error, or ErrGroupCommitShutdown if the
// committer has been stopped.
func (gc *GroupCommitter) Submit() error {
	resultCh := make(chan error, 1)
	gc.mu.Lock()
	if gc.stopped {
		gc.mu.Unlock()
		return ErrGroupCommitShutdown
	}
	gc.pending = append(gc.pending, resultCh)
	gc.mu.Unlock()
	// Wake Run() without blocking; an already-pending signal suffices.
	select {
	case gc.notifyCh <- struct{}{}:
	default:
	}
	return <-resultCh
}
// Stop shuts down the group committer and waits for Run to drain. Pending
// waiters receive ErrGroupCommitShutdown. Safe to call multiple times.
func (gc *GroupCommitter) Stop() {
	gc.stopOnce.Do(func() { close(gc.stopCh) })
	<-gc.done
}
// SyncCount returns the number of fsyncs performed (for testing).
// The counter is atomic, so this is safe to call from any goroutine.
func (gc *GroupCommitter) SyncCount() uint64 {
return gc.syncCount.Load()
}
// markStoppedAndDrain sets the stopped flag under mu, then fails every
// pending waiter with ErrGroupCommitShutdown. Because the flag is set and the
// queue is emptied under the same lock acquisition, no Submit() can enqueue a
// waiter after the drain -- this closes the race window from QA-002.
func (gc *GroupCommitter) markStoppedAndDrain() {
	gc.mu.Lock()
	gc.stopped = true
	waiters := gc.pending
	gc.pending = nil
	gc.mu.Unlock()
	for _, w := range waiters {
		w <- ErrGroupCommitShutdown
	}
}

287
weed/storage/blockvol/group_commit_test.go

@ -0,0 +1,287 @@
package blockvol
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestGroupCommitter runs each GroupCommitter subtest under its own name.
func TestGroupCommitter(t *testing.T) {
	for _, tc := range []struct {
		name string
		fn   func(*testing.T)
	}{
		{"group_single_barrier", testGroupSingleBarrier},
		{"group_batch_10", testGroupBatch10},
		{"group_max_delay", testGroupMaxDelay},
		{"group_max_batch", testGroupMaxBatch},
		{"group_fsync_error", testGroupFsyncError},
		{"group_sequential", testGroupSequential},
		{"group_shutdown", testGroupShutdown},
		{"group_submit_after_stop", testGroupSubmitAfterStop},
	} {
		t.Run(tc.name, tc.fn)
	}
}

// testGroupSingleBarrier: a single Submit triggers exactly one fsync.
func testGroupSingleBarrier(t *testing.T) {
	var fsyncs atomic.Uint64
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			fsyncs.Add(1)
			return nil
		},
		MaxDelay: 10 * time.Millisecond,
	})
	go gc.Run()
	defer gc.Stop()
	if err := gc.Submit(); err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if c := fsyncs.Load(); c != 1 {
		t.Errorf("syncCalls = %d, want 1", c)
	}
}
// testGroupBatch10: ten concurrent Submits coalesce into at most two fsyncs.
func testGroupBatch10(t *testing.T) {
	var fsyncs atomic.Uint64
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			fsyncs.Add(1)
			return nil
		},
		MaxDelay: 50 * time.Millisecond,
		MaxBatch: 64,
	})
	go gc.Run()
	defer gc.Stop()
	const n = 10
	results := make([]error, n)
	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func(idx int) {
			defer wg.Done()
			results[idx] = gc.Submit()
		}(i)
	}
	wg.Wait()
	for i, err := range results {
		if err != nil {
			t.Errorf("Submit[%d]: %v", i, err)
		}
	}
	// All 10 should be batched into 1 fsync (maybe 2 if timing is unlucky).
	if c := fsyncs.Load(); c > 2 {
		t.Errorf("syncCalls = %d, want 1-2 (batched)", c)
	}
}

// testGroupMaxDelay: a lone Submit completes within roughly MaxDelay.
func testGroupMaxDelay(t *testing.T) {
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error { return nil },
		MaxDelay: 5 * time.Millisecond,
	})
	go gc.Run()
	defer gc.Stop()
	began := time.Now()
	if err := gc.Submit(); err != nil {
		t.Fatalf("Submit: %v", err)
	}
	// Should complete within ~maxDelay plus generous scheduling margin.
	if took := time.Since(began); took > 50*time.Millisecond {
		t.Errorf("Submit took %v, expected ~5ms", took)
	}
}
// testGroupMaxBatch: hitting MaxBatch flushes immediately rather than
// waiting out a deliberately huge MaxDelay.
func testGroupMaxBatch(t *testing.T) {
	var fsyncs atomic.Uint64
	const batch = 64
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			fsyncs.Add(1)
			return nil
		},
		MaxDelay: 5 * time.Second, // very long delay -- should NOT wait this long
		MaxBatch: batch,
	})
	go gc.Run()
	defer gc.Stop()
	var wg sync.WaitGroup
	wg.Add(batch)
	for i := 0; i < batch; i++ {
		go func() {
			defer wg.Done()
			gc.Submit()
		}()
	}
	finished := make(chan struct{})
	go func() {
		wg.Wait()
		close(finished)
	}()
	select {
	case <-finished:
		// Completed without waiting out the 5-second MaxDelay.
	case <-time.After(2 * time.Second):
		t.Fatal("maxBatch=64 did not trigger immediate flush")
	}
}

// testGroupFsyncError: every waiter in a failing batch sees the fsync error,
// and the degraded callback fires.
func testGroupFsyncError(t *testing.T) {
	errIO := errors.New("simulated EIO")
	var degraded atomic.Bool
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   func() error { return errIO },
		MaxDelay:   10 * time.Millisecond,
		OnDegraded: func() { degraded.Store(true) },
	})
	go gc.Run()
	defer gc.Stop()
	const n = 5
	results := make([]error, n)
	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func(idx int) {
			defer wg.Done()
			results[idx] = gc.Submit()
		}(i)
	}
	wg.Wait()
	for i, err := range results {
		if !errors.Is(err, errIO) {
			t.Errorf("Submit[%d]: got %v, want %v", i, err, errIO)
		}
	}
	if !degraded.Load() {
		t.Error("onDegraded was not called")
	}
}
// testGroupSequential: two back-to-back Submits each get their own fsync.
func testGroupSequential(t *testing.T) {
	var fsyncs atomic.Uint64
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			fsyncs.Add(1)
			return nil
		},
		MaxDelay: 10 * time.Millisecond,
	})
	go gc.Run()
	defer gc.Stop()
	for _, label := range []string{"Submit 1", "Submit 2"} {
		if err := gc.Submit(); err != nil {
			t.Fatalf("%s: %v", label, err)
		}
	}
	if c := fsyncs.Load(); c != 2 {
		t.Errorf("syncCalls = %d, want 2 (two separate fsyncs)", c)
	}
}
// testGroupShutdown verifies Stop semantics while an fsync is in flight: the
// waiter already inside the running batch gets the fsync result (nil here),
// while the waiter queued for the NEXT batch is drained with
// ErrGroupCommitShutdown.
func testGroupShutdown(t *testing.T) {
// Block fsync so we can shut down while waiters are pending.
syncStarted := make(chan struct{})
syncBlock := make(chan struct{})
gc := NewGroupCommitter(GroupCommitterConfig{
SyncFunc: func() error {
close(syncStarted)
<-syncBlock // block until test releases
return nil
},
MaxDelay: 1 * time.Millisecond,
})
go gc.Run()
// Submit one request that will block on fsync.
errCh1 := make(chan error, 1)
go func() {
errCh1 <- gc.Submit()
}()
// Wait for fsync to start.
<-syncStarted
// Submit another request while fsync is in progress.
errCh2 := make(chan error, 1)
go func() {
errCh2 <- gc.Submit()
}()
time.Sleep(5 * time.Millisecond) // let it enqueue
// Stop the group committer (will drain pending with ErrGroupCommitShutdown).
// The fsync is released shortly after Stop begins so Run can finish the
// in-flight batch and then observe stopCh.
go func() {
time.Sleep(10 * time.Millisecond)
close(syncBlock) // unblock the fsync first
}()
gc.Stop()
// First waiter should get nil (fsync succeeded).
if err := <-errCh1; err != nil {
t.Errorf("first waiter: %v, want nil", err)
}
// Second waiter should get shutdown error.
if err := <-errCh2; !errors.Is(err, ErrGroupCommitShutdown) {
t.Errorf("second waiter: %v, want ErrGroupCommitShutdown", err)
}
}
// testGroupSubmitAfterStop: Submit on a stopped committer must return the
// shutdown sentinel promptly instead of blocking forever.
func testGroupSubmitAfterStop(t *testing.T) {
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error { return nil },
		MaxDelay: 10 * time.Millisecond,
	})
	go gc.Run()
	gc.Stop()
	result := make(chan error, 1)
	go func() { result <- gc.Submit() }()
	select {
	case err := <-result:
		if !errors.Is(err, ErrGroupCommitShutdown) {
			t.Errorf("Submit after Stop: %v, want ErrGroupCommitShutdown", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Submit after Stop deadlocked")
	}
}

141
weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go

@ -0,0 +1,141 @@
// iscsi-target is a standalone iSCSI target backed by a BlockVol file.
// Usage:
//
// iscsi-target -vol /path/to/volume.blk -addr :3260 -iqn iqn.2024.com.seaweedfs:vol1
// iscsi-target -create -size 1G -vol /path/to/volume.blk -addr :3260
package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol"
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi"
)
// main parses flags, creates or opens the backing BlockVol, wires it to an
// iSCSI target server via blockVolAdapter, and serves until SIGINT/SIGTERM.
func main() {
// Flag surface; -create switches from open-existing to create-new mode.
volPath := flag.String("vol", "", "path to BlockVol file")
addr := flag.String("addr", ":3260", "listen address")
iqn := flag.String("iqn", "iqn.2024.com.seaweedfs:vol1", "target IQN")
create := flag.Bool("create", false, "create a new volume file")
size := flag.String("size", "1G", "volume size (e.g., 1G, 100M) — used with -create")
flag.Parse()
if *volPath == "" {
fmt.Fprintln(os.Stderr, "error: -vol is required")
flag.Usage()
os.Exit(1)
}
logger := log.New(os.Stdout, "[iscsi] ", log.LstdFlags)
var vol *blockvol.BlockVol
var err error
if *create {
// Create a fresh volume with a fixed 4K block size and a 64MB WAL.
volSize, parseErr := parseSize(*size)
if parseErr != nil {
log.Fatalf("invalid size %q: %v", *size, parseErr)
}
vol, err = blockvol.CreateBlockVol(*volPath, blockvol.CreateOptions{
VolumeSize: volSize,
BlockSize: 4096,
WALSize: 64 * 1024 * 1024,
})
if err != nil {
log.Fatalf("create volume: %v", err)
}
logger.Printf("created volume: %s (%s)", *volPath, *size)
} else {
vol, err = blockvol.OpenBlockVol(*volPath)
if err != nil {
log.Fatalf("open volume: %v", err)
}
logger.Printf("opened volume: %s", *volPath)
}
defer vol.Close()
info := vol.Info()
logger.Printf("volume: %d bytes, block=%d, healthy=%v",
info.VolumeSize, info.BlockSize, info.Healthy)
// Create adapter
adapter := &blockVolAdapter{vol: vol}
// Create target server and expose the volume under the requested IQN.
config := iscsi.DefaultTargetConfig()
config.TargetName = *iqn
config.TargetAlias = "SeaweedFS BlockVol"
ts := iscsi.NewTargetServer(*addr, config, logger)
ts.AddVolume(*iqn, adapter)
// Graceful shutdown on signal: closing the server unblocks ListenAndServe.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigCh
logger.Printf("received %v, shutting down...", sig)
ts.Close()
}()
logger.Printf("starting iSCSI target: %s on %s", *iqn, *addr)
// ListenAndServe blocks until the server is closed by the signal handler.
if err := ts.ListenAndServe(); err != nil {
log.Fatalf("target server: %v", err)
}
logger.Println("target stopped")
}
// blockVolAdapter wraps BlockVol to implement iscsi.BlockDevice.
// Every method is a direct delegation to the underlying volume.
type blockVolAdapter struct {
vol *blockvol.BlockVol
}
// ReadAt reads length bytes of data starting at the given LBA.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
return a.vol.ReadLBA(lba, length)
}
// WriteAt writes data starting at the given LBA.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
return a.vol.WriteLBA(lba, data)
}
// Trim discards the extent at lba; length units follow BlockVol.Trim's
// contract (presumably bytes — confirm against the BlockDevice interface).
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
return a.vol.Trim(lba, length)
}
// SyncCache flushes the volume's write cache to stable storage.
func (a *blockVolAdapter) SyncCache() error { return a.vol.SyncCache() }
// BlockSize reports the volume block size in bytes.
func (a *blockVolAdapter) BlockSize() uint32 { return a.vol.Info().BlockSize }
// VolumeSize reports the total volume size in bytes.
func (a *blockVolAdapter) VolumeSize() uint64 { return a.vol.Info().VolumeSize }
// IsHealthy reports whether the underlying volume is healthy.
func (a *blockVolAdapter) IsHealthy() bool { return a.vol.Info().Healthy }
// parseSize parses a human-readable size string such as "512", "4K", "100M",
// "1G", or "2T" (suffixes case-insensitive, powers of 1024) into a byte count.
// It returns an error for empty input, malformed numbers, and values whose
// byte count would overflow uint64 (previously the multiplication wrapped
// silently).
func parseSize(s string) (uint64, error) {
	s = strings.TrimSpace(s)
	if len(s) == 0 {
		return 0, fmt.Errorf("empty size")
	}
	multiplier := uint64(1)
	// Peel a trailing unit suffix, if any; no suffix means plain bytes.
	switch s[len(s)-1] {
	case 'K', 'k':
		multiplier = 1024
		s = s[:len(s)-1]
	case 'M', 'm':
		multiplier = 1024 * 1024
		s = s[:len(s)-1]
	case 'G', 'g':
		multiplier = 1024 * 1024 * 1024
		s = s[:len(s)-1]
	case 'T', 't':
		multiplier = 1024 * 1024 * 1024 * 1024
		s = s[:len(s)-1]
	}
	n, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return 0, err
	}
	// Reject values whose byte count would wrap around uint64.
	if multiplier != 1 && n > ^uint64(0)/multiplier {
		return 0, fmt.Errorf("size %q overflows uint64", s)
	}
	return n * multiplier, nil
}

216
weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh

@ -0,0 +1,216 @@
#!/usr/bin/env bash
# smoke-test.sh — iscsiadm smoke test for SeaweedFS iSCSI target
#
# Prerequisites:
# - Linux host with iscsiadm (open-iscsi) installed
# - Root access (for iscsiadm, mkfs, mount)
# - iscsi-target binary built: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/
#
# Usage:
# sudo ./smoke-test.sh [target-addr] [target-port]
#
# Default: starts a local target on :3260, runs discovery + login + I/O, cleans up.
set -euo pipefail
# Positional args: $1 = target address (default 127.0.0.1), $2 = port (default 3260).
TARGET_ADDR="${1:-127.0.0.1}"
TARGET_PORT="${2:-3260}"
TARGET_IQN="iqn.2024.com.seaweedfs:smoke"
VOL_FILE="/tmp/iscsi-smoke-test.blk"
MOUNT_POINT="/tmp/iscsi-smoke-mnt"
TARGET_PID=""   # populated once the target process is launched
# ANSI colors for log output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
# Logging helpers: log = progress, warn = non-fatal issue, fail = print and exit 1.
log() { echo -e "${GREEN}[SMOKE]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
# cleanup — best-effort teardown, installed as an EXIT trap below.
# Order matters: unmount before iSCSI logout, logout before killing the
# target, target down before removing its backing file. Each step
# tolerates failure (|| true) so one missing resource doesn't abort the rest.
cleanup() {
    log "Cleaning up..."
    # Unmount if mounted
    if mountpoint -q "$MOUNT_POINT" 2>/dev/null; then
        umount "$MOUNT_POINT" || true
    fi
    rmdir "$MOUNT_POINT" 2>/dev/null || true
    # Logout from iSCSI
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout 2>/dev/null || true
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" -o delete 2>/dev/null || true
    # Stop target (only if it was started and is still alive)
    if [[ -n "$TARGET_PID" ]] && kill -0 "$TARGET_PID" 2>/dev/null; then
        kill "$TARGET_PID"
        wait "$TARGET_PID" 2>/dev/null || true
    fi
    # Remove volume file
    rm -f "$VOL_FILE"
    log "Cleanup complete"
}
trap cleanup EXIT
# -------------------------------------------------------
# Preflight
# -------------------------------------------------------
if [[ $EUID -ne 0 ]]; then
fail "This script must be run as root (for iscsiadm/mount)"
fi
if ! command -v iscsiadm &>/dev/null; then
fail "iscsiadm not found. Install open-iscsi: apt install open-iscsi"
fi
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TARGET_BIN="${SCRIPT_DIR}/iscsi-target"
if [[ ! -x "$TARGET_BIN" ]]; then
# Try the repo root
TARGET_BIN="$(cd "$SCRIPT_DIR/../../../../../.." && pwd)/iscsi-target"
fi
if [[ ! -x "$TARGET_BIN" ]]; then
fail "iscsi-target binary not found. Build with: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/"
fi
# -------------------------------------------------------
# Step 1: Start iSCSI target
# -------------------------------------------------------
log "Starting iSCSI target..."
"$TARGET_BIN" \
-create \
-vol "$VOL_FILE" \
-size 100M \
-addr "${TARGET_ADDR}:${TARGET_PORT}" \
-iqn "$TARGET_IQN" &
TARGET_PID=$!
sleep 1
if ! kill -0 "$TARGET_PID" 2>/dev/null; then
fail "Target failed to start"
fi
log "Target started (PID $TARGET_PID)"
# -------------------------------------------------------
# Step 2: Discovery
# -------------------------------------------------------
log "Running iscsiadm discovery..."
DISCOVERY=$(iscsiadm -m discovery -t sendtargets -p "${TARGET_ADDR}:${TARGET_PORT}" 2>&1) || {
fail "Discovery failed: $DISCOVERY"
}
echo "$DISCOVERY"
if ! echo "$DISCOVERY" | grep -q "$TARGET_IQN"; then
fail "Target IQN not found in discovery output"
fi
log "Discovery OK"
# -------------------------------------------------------
# Step 3: Login
# -------------------------------------------------------
log "Logging in to target..."
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --login || {
fail "Login failed"
}
log "Login OK"
# Wait for udev to create the block device node after login.
sleep 2
# Identify the disk attached by this session from iscsiadm's session report.
ISCSI_DEV=$(iscsiadm -m session -P 3 2>/dev/null | grep "Attached scsi disk" | awk '{print $4}' | head -1)
if [[ -z "$ISCSI_DEV" ]]; then
    # NOTE(review): guessing /dev/sdb is dangerous on hosts with more than
    # one disk — a wrong guess sends the later dd/mkfs at an unrelated
    # device. Consider failing here instead of guessing.
    warn "Could not determine attached device — trying /dev/sdb"
    ISCSI_DEV="sdb"
fi
DEV_PATH="/dev/$ISCSI_DEV"
log "Attached device: $DEV_PATH"
if [[ ! -b "$DEV_PATH" ]]; then
    fail "Block device $DEV_PATH not found"
fi
# -------------------------------------------------------
# Step 4: dd write/read test
# -------------------------------------------------------
log "Writing test pattern with dd..."
dd if=/dev/urandom of=/tmp/iscsi-smoke-pattern bs=1M count=1 2>/dev/null
dd if=/tmp/iscsi-smoke-pattern of="$DEV_PATH" bs=1M count=1 oflag=direct 2>/dev/null || {
fail "dd write failed"
}
log "Reading back and verifying..."
dd if="$DEV_PATH" of=/tmp/iscsi-smoke-readback bs=1M count=1 iflag=direct 2>/dev/null || {
fail "dd read failed"
}
if cmp -s /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback; then
log "Data integrity verified (dd write/read OK)"
else
fail "Data integrity check failed!"
fi
rm -f /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback
# -------------------------------------------------------
# Step 5: mkfs + mount + file I/O
# -------------------------------------------------------
log "Creating ext4 filesystem..."
mkfs.ext4 -F -q "$DEV_PATH" || {
warn "mkfs.ext4 failed — skipping filesystem tests"
# Still consider the test a pass if dd worked
log "SMOKE TEST PASSED (dd only, mkfs skipped)"
exit 0
}
mkdir -p "$MOUNT_POINT"
mount "$DEV_PATH" "$MOUNT_POINT" || {
fail "mount failed"
}
log "Mounted at $MOUNT_POINT"
# Write a file
echo "SeaweedFS iSCSI smoke test" > "$MOUNT_POINT/test.txt"
sync
# Read it back
CONTENT=$(cat "$MOUNT_POINT/test.txt")
if [[ "$CONTENT" != "SeaweedFS iSCSI smoke test" ]]; then
fail "File content mismatch"
fi
log "File I/O verified"
# -------------------------------------------------------
# Step 6: Logout
# -------------------------------------------------------
log "Unmounting and logging out..."
umount "$MOUNT_POINT"
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout || {
fail "Logout failed"
}
log "Logout OK"
# Verify no stale sessions
SESSION_COUNT=$(iscsiadm -m session 2>/dev/null | grep -c "$TARGET_IQN" || true)
if [[ "$SESSION_COUNT" -gt 0 ]]; then
fail "Stale session detected after logout"
fi
log "No stale sessions"
# -------------------------------------------------------
# Result
# -------------------------------------------------------
echo ""
echo -e "${GREEN}========================================${NC}"
echo -e "${GREEN} SMOKE TEST PASSED${NC}"
echo -e "${GREEN}========================================${NC}"
echo ""
echo " Discovery: OK"
echo " Login: OK"
echo " dd I/O: OK"
echo " mkfs+mount: OK"
echo " File I/O: OK"
echo " Logout: OK"
echo " Cleanup: OK"
echo ""

207
weed/storage/blockvol/iscsi/dataio.go

@ -0,0 +1,207 @@
package iscsi
import (
"errors"
"io"
)
var (
	// ErrDataSNOrder reports a Data-Out PDU arriving with an unexpected DataSN.
	ErrDataSNOrder = errors.New("iscsi: Data-Out DataSN out of order")
	// ErrDataOverflow reports data that would exceed the expected transfer length.
	ErrDataOverflow = errors.New("iscsi: data exceeds expected transfer length")
	// ErrDataIncomplete reports a transfer that ended before all bytes arrived.
	ErrDataIncomplete = errors.New("iscsi: data transfer incomplete")
)
// DataInWriter splits a read response into multiple Data-In PDUs, respecting
// MaxRecvDataSegmentLength. The final PDU carries the S-bit (status) and F-bit.
type DataInWriter struct {
	maxSegLen uint32 // negotiated MaxRecvDataSegmentLength
}

// NewDataInWriter creates a writer with the given max segment length.
// A zero maxSegLen selects a default of 8192 bytes.
func NewDataInWriter(maxSegLen uint32) *DataInWriter {
	const defaultSegLen = 8192
	dw := &DataInWriter{maxSegLen: maxSegLen}
	if dw.maxSegLen == 0 {
		dw.maxSegLen = defaultSegLen
	}
	return dw
}
// WriteDataIn splits data into Data-In PDUs and writes them to w.
// itt is the initiator task tag; the Target Transfer Tag is always set
// to 0xFFFFFFFF (reserved — no Data-Out phase follows a read).
// statSN is the current StatSN; it is incremented exactly once, when the
// final PDU carrying the S-bit (phase-collapsed status) is written.
// Returns the number of PDUs written.
func (d *DataInWriter) WriteDataIn(w io.Writer, data []byte, itt uint32, expCmdSN, maxCmdSN uint32, statSN *uint32) (int, error) {
	totalLen := uint32(len(data))
	if totalLen == 0 {
		// Zero-length read — send single Data-In with S-bit, no data
		pdu := &PDU{}
		pdu.SetOpcode(OpSCSIDataIn)
		pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
		pdu.SetInitiatorTaskTag(itt)
		pdu.SetTargetTransferTag(0xFFFFFFFF)
		pdu.SetStatSN(*statSN)
		*statSN++
		pdu.SetExpCmdSN(expCmdSN)
		pdu.SetMaxCmdSN(maxCmdSN)
		pdu.SetDataSN(0)
		pdu.SetSCSIStatus(SCSIStatusGood)
		if err := WritePDU(w, pdu); err != nil {
			return 0, err
		}
		return 1, nil
	}
	var offset uint32
	var dataSN uint32
	count := 0
	for offset < totalLen {
		// Clamp each segment to the negotiated maximum.
		// NOTE(review): offset+segLen can wrap uint32 for extreme
		// maxSegLen values near 4GiB; harmless for realistic
		// negotiated sizes — confirm upper bound at negotiation time.
		segLen := d.maxSegLen
		if offset+segLen > totalLen {
			segLen = totalLen - offset
		}
		isFinal := (offset + segLen) >= totalLen
		pdu := &PDU{}
		pdu.SetOpcode(OpSCSIDataIn)
		pdu.SetInitiatorTaskTag(itt)
		pdu.SetTargetTransferTag(0xFFFFFFFF)
		pdu.SetExpCmdSN(expCmdSN)
		pdu.SetMaxCmdSN(maxCmdSN)
		pdu.SetDataSN(dataSN) // DataSN increments once per Data-In PDU
		pdu.SetBufferOffset(offset)
		// The segment aliases the caller's buffer — no copy is made.
		pdu.DataSegment = data[offset : offset+segLen]
		if isFinal {
			// Phase-collapsed status: last Data-In carries F + S
			// and consumes one StatSN.
			pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
			pdu.SetStatSN(*statSN)
			*statSN++
			pdu.SetSCSIStatus(SCSIStatusGood)
		} else {
			pdu.SetOpSpecific1(0) // no flags
		}
		if err := WritePDU(w, pdu); err != nil {
			return count, err
		}
		count++
		dataSN++
		offset += segLen
	}
	return count, nil
}
// DataOutCollector collects Data-Out PDUs for a single write command,
// assembling the full data buffer from potentially multiple PDUs.
// It handles both immediate data and R2T-solicited data.
type DataOutCollector struct {
	expectedLen uint32 // total bytes this command will transfer
	buf         []byte // assembled payload; len(buf) == expectedLen
	received    uint32 // bytes accumulated so far
	nextDataSN  uint32 // DataSN expected on the next Data-Out PDU
	done        bool   // set when all data arrived (or F-bit was seen)
}

// NewDataOutCollector creates a collector expecting the given total transfer length.
func NewDataOutCollector(expectedLen uint32) *DataOutCollector {
	c := &DataOutCollector{expectedLen: expectedLen}
	c.buf = make([]byte, expectedLen)
	return c
}

// AddImmediateData adds the data from the SCSI Command PDU (immediate data).
// Immediate data always lands at the start of the buffer.
func (c *DataOutCollector) AddImmediateData(data []byte) error {
	n := uint32(len(data))
	if n > c.expectedLen {
		return ErrDataOverflow
	}
	copy(c.buf, data)
	c.received += n
	if c.received >= c.expectedLen {
		c.done = true
	}
	return nil
}
// AddDataOut processes a Data-Out PDU and adds its data to the buffer.
// PDUs must arrive with strictly increasing DataSN starting at 0
// (ErrDataSNOrder otherwise); data landing outside the expected transfer
// length is rejected with ErrDataOverflow. The F-bit marks the transfer
// sequence complete.
func (c *DataOutCollector) AddDataOut(pdu *PDU) error {
	dataSN := pdu.DataSN()
	if dataSN != c.nextDataSN {
		return ErrDataSNOrder
	}
	c.nextDataSN++
	offset := pdu.BufferOffset()
	data := pdu.DataSegment
	// Bounds-check each term separately: the previous form computed
	// end := offset + len(data) in uint32, which can wrap for a huge
	// BufferOffset, slip past the range test, and then panic in the
	// copy below. Checked subtraction cannot wrap.
	if offset > c.expectedLen || uint32(len(data)) > c.expectedLen-offset {
		return ErrDataOverflow
	}
	copy(c.buf[offset:], data)
	c.received += uint32(len(data))
	// Check F-bit
	if pdu.OpSpecific1()&FlagF != 0 {
		c.done = true
	}
	return nil
}
// Done returns true if all expected data has been received.
func (c *DataOutCollector) Done() bool { return c.done }

// Data returns the assembled data buffer. This is the collector's
// internal buffer (len == expectedLen), not a copy — callers must not
// retain it past the command's lifetime.
func (c *DataOutCollector) Data() []byte { return c.buf }

// Remaining returns how many bytes are still needed.
func (c *DataOutCollector) Remaining() uint32 {
	if c.received >= c.expectedLen {
		return 0
	}
	return c.expectedLen - c.received
}
// BuildR2T creates an R2T PDU requesting more data from the initiator.
// bufferOffset and desiredLen describe the byte range being solicited;
// r2tSN orders multiple R2Ts for the same command. statSN is echoed
// into the PDU but not advanced — R2T does not consume a StatSN.
func BuildR2T(itt, ttt uint32, r2tSN uint32, bufferOffset, desiredLen uint32, statSN, expCmdSN, maxCmdSN uint32) *PDU {
	pdu := &PDU{}
	pdu.SetOpcode(OpR2T)
	pdu.SetOpSpecific1(FlagF) // always Final for R2T
	pdu.SetInitiatorTaskTag(itt)
	pdu.SetTargetTransferTag(ttt)
	pdu.SetStatSN(statSN)
	pdu.SetExpCmdSN(expCmdSN)
	pdu.SetMaxCmdSN(maxCmdSN)
	pdu.SetR2TSN(r2tSN)
	pdu.SetBufferOffset(bufferOffset)
	pdu.SetDesiredDataLength(desiredLen)
	return pdu
}
// SendSCSIResponse sends a SCSI Response PDU with optional sense data.
// For CHECK CONDITION results the sense bytes are carried in the data
// segment, preceded by a 2-byte big-endian SenseLength field. StatSN is
// advanced by one.
func SendSCSIResponse(w io.Writer, result SCSIResult, itt uint32, statSN *uint32, expCmdSN, maxCmdSN uint32) error {
	resp := &PDU{}
	resp.SetOpcode(OpSCSIResp)
	resp.SetOpSpecific1(FlagF) // Final
	resp.SetSCSIResponse(ISCSIRespCompleted)
	resp.SetSCSIStatus(result.Status)
	resp.SetInitiatorTaskTag(itt)
	resp.SetStatSN(*statSN)
	*statSN++
	resp.SetExpCmdSN(expCmdSN)
	resp.SetMaxCmdSN(maxCmdSN)
	if result.Status == SCSIStatusCheckCond {
		sense := BuildSenseData(result.SenseKey, result.SenseASC, result.SenseASCQ)
		// Sense data is wrapped in a 2-byte length prefix.
		seg := make([]byte, 2+len(sense))
		seg[0] = byte(len(sense) >> 8)
		seg[1] = byte(len(sense))
		copy(seg[2:], sense)
		resp.DataSegment = seg
	}
	return WritePDU(w, resp)
}

410
weed/storage/blockvol/iscsi/dataio_test.go

@ -0,0 +1,410 @@
package iscsi
import (
"bytes"
"testing"
)
// TestDataIO runs the data-path subtests as named subcases, so a single
// scenario can be selected with -run 'TestDataIO/<name>'.
func TestDataIO(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"datain_single_pdu", testDataInSinglePDU},
		{"datain_multi_pdu", testDataInMultiPDU},
		{"datain_exact_boundary", testDataInExactBoundary},
		{"datain_zero_length", testDataInZeroLength},
		{"datain_datasn_ordering", testDataInDataSNOrdering},
		{"datain_fbit_sbit", testDataInFbitSbit},
		{"dataout_single_pdu", testDataOutSinglePDU},
		{"dataout_multi_pdu", testDataOutMultiPDU},
		{"dataout_immediate_data", testDataOutImmediateData},
		{"dataout_immediate_plus_r2t", testDataOutImmediatePlusR2T},
		{"dataout_wrong_datasn", testDataOutWrongDataSN},
		{"dataout_overflow", testDataOutOverflow},
		{"r2t_build", testR2TBuild},
		{"scsi_response_good", testSCSIResponseGood},
		{"scsi_response_check_condition", testSCSIResponseCheckCondition},
		{"datain_statsn_increment", testDataInStatSNIncrement},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
func testDataInSinglePDU(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(8192)
data := bytes.Repeat([]byte{0xAA}, 4096)
statSN := uint32(1)
n, err := dw.WriteDataIn(w, data, 0x100, 1, 10, &statSN)
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("expected 1 PDU, got %d", n)
}
if statSN != 2 {
t.Fatalf("StatSN should be 2, got %d", statSN)
}
pdu, err := ReadPDU(w)
if err != nil {
t.Fatal(err)
}
if pdu.Opcode() != OpSCSIDataIn {
t.Fatal("wrong opcode")
}
if pdu.OpSpecific1()&FlagF == 0 {
t.Fatal("F-bit not set")
}
if pdu.OpSpecific1()&FlagS == 0 {
t.Fatal("S-bit not set on final PDU")
}
if len(pdu.DataSegment) != 4096 {
t.Fatalf("data length: %d", len(pdu.DataSegment))
}
}
func testDataInMultiPDU(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(1024) // small segment
data := bytes.Repeat([]byte{0xBB}, 3000)
statSN := uint32(1)
n, err := dw.WriteDataIn(w, data, 0x200, 1, 10, &statSN)
if err != nil {
t.Fatal(err)
}
// 3000 / 1024 = 3 PDUs (1024 + 1024 + 952)
if n != 3 {
t.Fatalf("expected 3 PDUs, got %d", n)
}
// Read them back
var reassembled []byte
for i := 0; i < 3; i++ {
pdu, err := ReadPDU(w)
if err != nil {
t.Fatalf("PDU %d: %v", i, err)
}
if pdu.DataSN() != uint32(i) {
t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN())
}
reassembled = append(reassembled, pdu.DataSegment...)
if i < 2 {
if pdu.OpSpecific1()&FlagF != 0 {
t.Fatalf("PDU %d should not have F-bit", i)
}
} else {
if pdu.OpSpecific1()&FlagF == 0 {
t.Fatal("last PDU should have F-bit")
}
if pdu.OpSpecific1()&FlagS == 0 {
t.Fatal("last PDU should have S-bit")
}
}
}
if !bytes.Equal(reassembled, data) {
t.Fatal("reassembled data mismatch")
}
}
func testDataInExactBoundary(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(1024)
data := bytes.Repeat([]byte{0xCC}, 2048) // exact 2 PDUs
statSN := uint32(1)
n, err := dw.WriteDataIn(w, data, 0x300, 1, 10, &statSN)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 PDUs, got %d", n)
}
}
func testDataInZeroLength(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(8192)
statSN := uint32(5)
n, err := dw.WriteDataIn(w, nil, 0x400, 1, 10, &statSN)
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("expected 1 PDU for zero-length, got %d", n)
}
if statSN != 6 {
t.Fatal("StatSN should still increment")
}
}
func testDataInDataSNOrdering(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(512)
data := bytes.Repeat([]byte{0xDD}, 2048) // 4 PDUs
statSN := uint32(1)
dw.WriteDataIn(w, data, 0x500, 1, 10, &statSN)
for i := 0; i < 4; i++ {
pdu, _ := ReadPDU(w)
if pdu.DataSN() != uint32(i) {
t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN())
}
expectedOffset := uint32(i) * 512
if pdu.BufferOffset() != expectedOffset {
t.Fatalf("PDU %d: offset=%d, expected %d", i, pdu.BufferOffset(), expectedOffset)
}
}
}
func testDataInFbitSbit(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(1000)
data := bytes.Repeat([]byte{0xEE}, 2500)
statSN := uint32(1)
dw.WriteDataIn(w, data, 0x600, 1, 10, &statSN)
for i := 0; i < 3; i++ {
pdu, _ := ReadPDU(w)
flags := pdu.OpSpecific1()
if i < 2 {
if flags&FlagF != 0 || flags&FlagS != 0 {
t.Fatalf("PDU %d: should have no F/S bits", i)
}
} else {
if flags&FlagF == 0 || flags&FlagS == 0 {
t.Fatal("last PDU must have F+S bits")
}
}
}
}
func testDataOutSinglePDU(t *testing.T) {
c := NewDataOutCollector(4096)
pdu := &PDU{}
pdu.SetOpcode(OpSCSIDataOut)
pdu.SetOpSpecific1(FlagF)
pdu.SetDataSN(0)
pdu.SetBufferOffset(0)
pdu.DataSegment = bytes.Repeat([]byte{0x11}, 4096)
if err := c.AddDataOut(pdu); err != nil {
t.Fatal(err)
}
if !c.Done() {
t.Fatal("should be done")
}
if c.Remaining() != 0 {
t.Fatal("remaining should be 0")
}
}
func testDataOutMultiPDU(t *testing.T) {
c := NewDataOutCollector(8192)
for i := 0; i < 2; i++ {
pdu := &PDU{}
pdu.SetOpcode(OpSCSIDataOut)
pdu.SetDataSN(uint32(i))
pdu.SetBufferOffset(uint32(i) * 4096)
pdu.DataSegment = bytes.Repeat([]byte{byte(i + 1)}, 4096)
if i == 1 {
pdu.SetOpSpecific1(FlagF)
}
if err := c.AddDataOut(pdu); err != nil {
t.Fatalf("PDU %d: %v", i, err)
}
}
if !c.Done() {
t.Fatal("should be done")
}
data := c.Data()
if data[0] != 0x01 || data[4096] != 0x02 {
t.Fatal("data assembly wrong")
}
}
func testDataOutImmediateData(t *testing.T) {
c := NewDataOutCollector(4096)
err := c.AddImmediateData(bytes.Repeat([]byte{0xFF}, 4096))
if err != nil {
t.Fatal(err)
}
if !c.Done() {
t.Fatal("should be done with immediate data")
}
}
func testDataOutImmediatePlusR2T(t *testing.T) {
c := NewDataOutCollector(8192)
// Immediate: first 4096
err := c.AddImmediateData(bytes.Repeat([]byte{0xAA}, 4096))
if err != nil {
t.Fatal(err)
}
if c.Done() {
t.Fatal("should not be done yet")
}
if c.Remaining() != 4096 {
t.Fatalf("remaining: %d", c.Remaining())
}
// R2T-solicited Data-Out: next 4096
pdu := &PDU{}
pdu.SetOpcode(OpSCSIDataOut)
pdu.SetOpSpecific1(FlagF)
pdu.SetDataSN(0)
pdu.SetBufferOffset(4096)
pdu.DataSegment = bytes.Repeat([]byte{0xBB}, 4096)
if err := c.AddDataOut(pdu); err != nil {
t.Fatal(err)
}
if !c.Done() {
t.Fatal("should be done")
}
data := c.Data()
if data[0] != 0xAA || data[4096] != 0xBB {
t.Fatal("assembly wrong")
}
}
func testDataOutWrongDataSN(t *testing.T) {
c := NewDataOutCollector(8192)
pdu := &PDU{}
pdu.SetOpcode(OpSCSIDataOut)
pdu.SetDataSN(1) // should be 0
pdu.SetBufferOffset(0)
pdu.DataSegment = make([]byte, 4096)
err := c.AddDataOut(pdu)
if err != ErrDataSNOrder {
t.Fatalf("expected ErrDataSNOrder, got %v", err)
}
}
func testDataOutOverflow(t *testing.T) {
c := NewDataOutCollector(4096)
pdu := &PDU{}
pdu.SetOpcode(OpSCSIDataOut)
pdu.SetDataSN(0)
pdu.SetBufferOffset(0)
pdu.DataSegment = make([]byte, 8192) // more than expected
err := c.AddDataOut(pdu)
if err != ErrDataOverflow {
t.Fatalf("expected ErrDataOverflow, got %v", err)
}
}
func testR2TBuild(t *testing.T) {
pdu := BuildR2T(0x100, 0x200, 0, 4096, 4096, 5, 10, 20)
if pdu.Opcode() != OpR2T {
t.Fatal("wrong opcode")
}
if pdu.InitiatorTaskTag() != 0x100 {
t.Fatal("ITT wrong")
}
if pdu.TargetTransferTag() != 0x200 {
t.Fatal("TTT wrong")
}
if pdu.R2TSN() != 0 {
t.Fatal("R2TSN wrong")
}
if pdu.BufferOffset() != 4096 {
t.Fatal("offset wrong")
}
if pdu.DesiredDataLength() != 4096 {
t.Fatal("desired length wrong")
}
if pdu.StatSN() != 5 {
t.Fatal("StatSN wrong")
}
}
func testSCSIResponseGood(t *testing.T) {
w := &bytes.Buffer{}
statSN := uint32(10)
result := SCSIResult{Status: SCSIStatusGood}
err := SendSCSIResponse(w, result, 0x300, &statSN, 5, 15)
if err != nil {
t.Fatal(err)
}
if statSN != 11 {
t.Fatal("StatSN not incremented")
}
pdu, err := ReadPDU(w)
if err != nil {
t.Fatal(err)
}
if pdu.Opcode() != OpSCSIResp {
t.Fatal("wrong opcode")
}
if pdu.SCSIStatus() != SCSIStatusGood {
t.Fatal("status wrong")
}
if len(pdu.DataSegment) != 0 {
t.Fatal("no data expected for good status")
}
}
func testSCSIResponseCheckCondition(t *testing.T) {
w := &bytes.Buffer{}
statSN := uint32(20)
result := SCSIResult{
Status: SCSIStatusCheckCond,
SenseKey: SenseIllegalRequest,
SenseASC: ASCInvalidOpcode,
SenseASCQ: ASCQLuk,
}
err := SendSCSIResponse(w, result, 0x400, &statSN, 5, 15)
if err != nil {
t.Fatal(err)
}
pdu, err := ReadPDU(w)
if err != nil {
t.Fatal(err)
}
if pdu.SCSIStatus() != SCSIStatusCheckCond {
t.Fatal("status wrong")
}
// Data segment should contain sense data with 2-byte length prefix
if len(pdu.DataSegment) < 20 { // 2 + 18
t.Fatalf("sense data too short: %d", len(pdu.DataSegment))
}
senseLen := int(pdu.DataSegment[0])<<8 | int(pdu.DataSegment[1])
if senseLen != 18 {
t.Fatalf("sense length: %d", senseLen)
}
}
func testDataInStatSNIncrement(t *testing.T) {
w := &bytes.Buffer{}
dw := NewDataInWriter(1024)
data := bytes.Repeat([]byte{0x00}, 3072) // 3 PDUs
statSN := uint32(100)
dw.WriteDataIn(w, data, 0x700, 1, 10, &statSN)
// Only the final PDU has S-bit, so StatSN increments once
if statSN != 101 {
t.Fatalf("StatSN should be 101, got %d", statSN)
}
}

68
weed/storage/blockvol/iscsi/discovery.go

@ -0,0 +1,68 @@
package iscsi
// HandleTextRequest processes a Text Request PDU.
// Currently supports SendTargets discovery (RFC 7143, Section 12.3).
func HandleTextRequest(req *PDU, targets []DiscoveryTarget) *PDU {
	resp := &PDU{}
	resp.SetOpcode(OpTextResp)
	resp.SetOpSpecific1(FlagF) // Final
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetTargetTransferTag(0xFFFFFFFF) // no continuation

	// A malformed parameter list yields an empty (but valid) response.
	params, err := ParseParams(req.DataSegment)
	if err != nil {
		return resp
	}
	sendTargets, ok := params.Get("SendTargets")
	if !ok {
		// Not a SendTargets request — nothing to report.
		return resp
	}

	// Select which targets to advertise: "All" means every target,
	// anything else matches at most one target by exact IQN.
	// Discovery responses can repeat keys (one TargetName per target),
	// so the body is built with EncodeDiscoveryTargets, not Params.
	var selected []DiscoveryTarget
	if sendTargets == "All" {
		selected = targets
	} else {
		for i := range targets {
			if targets[i].Name == sendTargets {
				selected = append(selected, targets[i])
				break
			}
		}
	}
	if body := EncodeDiscoveryTargets(selected); len(body) > 0 {
		resp.DataSegment = body
	}
	return resp
}
// DiscoveryTarget represents a target available for discovery.
type DiscoveryTarget struct {
	Name    string // IQN, e.g., "iqn.2024.com.seaweedfs:vol1"
	Address string // IP:port,portal-group, e.g., "10.0.0.1:3260,1"
}

// EncodeDiscoveryTargets encodes multiple targets into text parameter format.
// Each target produces TargetName=<iqn>\0 followed — only when an address
// is set — by TargetAddress=<addr>\0. Unlike normal negotiation params,
// the same key may appear once per target. An empty list yields nil.
func EncodeDiscoveryTargets(targets []DiscoveryTarget) []byte {
	if len(targets) == 0 {
		return nil
	}
	var out []byte
	for _, tgt := range targets {
		out = append(out, "TargetName="...)
		out = append(out, tgt.Name...)
		out = append(out, 0)
		if tgt.Address != "" {
			out = append(out, "TargetAddress="...)
			out = append(out, tgt.Address...)
			out = append(out, 0)
		}
	}
	return out
}

213
weed/storage/blockvol/iscsi/discovery_test.go

@ -0,0 +1,213 @@
package iscsi
import (
"strings"
"testing"
)
// TestDiscovery runs the SendTargets discovery subtests as named
// subcases, selectable with -run 'TestDiscovery/<name>'.
func TestDiscovery(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"send_targets_all", testSendTargetsAll},
		{"send_targets_specific", testSendTargetsSpecific},
		{"send_targets_not_found", testSendTargetsNotFound},
		{"send_targets_empty_list", testSendTargetsEmptyList},
		{"no_send_targets_key", testNoSendTargetsKey},
		{"malformed_text_request", testMalformedTextRequest},
		{"special_chars_in_iqn", testSpecialCharsInIQN},
		{"multiple_targets", testMultipleTargets},
		{"encode_discovery_targets", testEncodeDiscoveryTargets},
		{"encode_empty_targets", testEncodeEmptyTargets},
		{"target_without_address", testTargetWithoutAddress},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
// testTargets is the shared fixture target list used by most discovery subtests.
var testTargets = []DiscoveryTarget{
	{Name: "iqn.2024.com.seaweedfs:vol1", Address: "10.0.0.1:3260,1"},
	{Name: "iqn.2024.com.seaweedfs:vol2", Address: "10.0.0.2:3260,1"},
}

// makeTextReq builds a Text Request PDU carrying the given (possibly
// nil) parameter list, with a fixed ITT so responses can be matched.
func makeTextReq(params *Params) *PDU {
	p := &PDU{}
	p.SetOpcode(OpTextReq)
	p.SetOpSpecific1(FlagF)
	p.SetInitiatorTaskTag(0x1234)
	if params != nil {
		p.DataSegment = params.Encode()
	}
	return p
}
func testSendTargetsAll(t *testing.T) {
params := NewParams()
params.Set("SendTargets", "All")
req := makeTextReq(params)
resp := HandleTextRequest(req, testTargets)
if resp.Opcode() != OpTextResp {
t.Fatalf("opcode: 0x%02x", resp.Opcode())
}
if resp.InitiatorTaskTag() != 0x1234 {
t.Fatal("ITT mismatch")
}
body := string(resp.DataSegment)
if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol1") {
t.Fatal("missing vol1")
}
if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol2") {
t.Fatal("missing vol2")
}
if !strings.Contains(body, "10.0.0.1:3260,1") {
t.Fatal("missing vol1 address")
}
}
func testSendTargetsSpecific(t *testing.T) {
params := NewParams()
params.Set("SendTargets", "iqn.2024.com.seaweedfs:vol2")
req := makeTextReq(params)
resp := HandleTextRequest(req, testTargets)
body := string(resp.DataSegment)
if !strings.Contains(body, "vol2") {
t.Fatal("missing vol2")
}
// Should not contain vol1
if strings.Contains(body, "vol1") {
t.Fatal("should not contain vol1")
}
}
func testSendTargetsNotFound(t *testing.T) {
params := NewParams()
params.Set("SendTargets", "iqn.2024.com.seaweedfs:nonexistent")
req := makeTextReq(params)
resp := HandleTextRequest(req, testTargets)
if len(resp.DataSegment) != 0 {
t.Fatalf("expected empty response, got %q", resp.DataSegment)
}
}
func testSendTargetsEmptyList(t *testing.T) {
params := NewParams()
params.Set("SendTargets", "All")
req := makeTextReq(params)
resp := HandleTextRequest(req, nil)
if len(resp.DataSegment) != 0 {
t.Fatalf("expected empty response, got %q", resp.DataSegment)
}
}
func testNoSendTargetsKey(t *testing.T) {
params := NewParams()
params.Set("SomethingElse", "value")
req := makeTextReq(params)
resp := HandleTextRequest(req, testTargets)
if len(resp.DataSegment) != 0 {
t.Fatal("expected empty response for non-SendTargets request")
}
}
func testMalformedTextRequest(t *testing.T) {
req := &PDU{}
req.SetOpcode(OpTextReq)
req.SetInitiatorTaskTag(0x5678)
req.DataSegment = []byte("not a valid param format") // no '='
resp := HandleTextRequest(req, testTargets)
// Should return empty response without error
if resp.Opcode() != OpTextResp {
t.Fatalf("opcode: 0x%02x", resp.Opcode())
}
if resp.InitiatorTaskTag() != 0x5678 {
t.Fatal("ITT mismatch")
}
}
func testSpecialCharsInIQN(t *testing.T) {
targets := []DiscoveryTarget{
{Name: "iqn.2024-01.com.example:storage.tape1.sys1.xyz", Address: "192.168.1.100:3260,1"},
}
params := NewParams()
params.Set("SendTargets", "All")
req := makeTextReq(params)
resp := HandleTextRequest(req, targets)
body := string(resp.DataSegment)
if !strings.Contains(body, "iqn.2024-01.com.example:storage.tape1.sys1.xyz") {
t.Fatal("IQN with special chars not found")
}
}
// testMultipleTargets sends SendTargets=All against 10 targets and
// checks the response body is non-empty.
func testMultipleTargets(t *testing.T) {
	targets := make([]DiscoveryTarget, 10)
	for i := range targets {
		targets[i] = DiscoveryTarget{
			Name:    "iqn.2024.com.test:vol" + string(rune('0'+i)),
			Address: "10.0.0.1:3260,1",
		}
	}
	params := NewParams()
	params.Set("SendTargets", "All")
	req := makeTextReq(params)
	resp := HandleTextRequest(req, targets)
	body := string(resp.DataSegment)
	// NOTE(review): HandleTextRequest builds the body with
	// EncodeDiscoveryTargets, which does emit one TargetName per target
	// (duplicate keys are allowed in discovery responses). This weak
	// assertion could be strengthened to verify all 10 names appear.
	if !strings.Contains(body, "TargetName=") {
		t.Fatal("no TargetName in response")
	}
}
func testEncodeDiscoveryTargets(t *testing.T) {
encoded := EncodeDiscoveryTargets(testTargets)
body := string(encoded)
// Should contain both targets with proper format
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") {
t.Fatal("missing vol1")
}
if !strings.Contains(body, "TargetAddress=10.0.0.1:3260,1\x00") {
t.Fatal("missing vol1 address")
}
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol2\x00") {
t.Fatal("missing vol2")
}
}
// testEncodeEmptyTargets verifies the nil-for-empty contract of
// EncodeDiscoveryTargets.
func testEncodeEmptyTargets(t *testing.T) {
	encoded := EncodeDiscoveryTargets(nil)
	if encoded != nil {
		t.Fatalf("expected nil, got %q", encoded)
	}
}
// testTargetWithoutAddress: a target with an empty Address must encode
// only its TargetName — no TargetAddress key at all.
func testTargetWithoutAddress(t *testing.T) {
	targets := []DiscoveryTarget{
		{Name: "iqn.2024.com.seaweedfs:vol1"}, // no address
	}
	encoded := EncodeDiscoveryTargets(targets)
	body := string(encoded)
	if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") {
		t.Fatal("missing target name")
	}
	if strings.Contains(body, "TargetAddress") {
		t.Fatal("should not have TargetAddress")
	}
}

412
weed/storage/blockvol/iscsi/integration_test.go

@ -0,0 +1,412 @@
package iscsi_test
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
"path/filepath"
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol"
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi"
)
// blockVolAdapter wraps a BlockVol to implement iscsi.BlockDevice.
// Test-local copy of the adapter in cmd/iscsi-target; pure delegation,
// no state beyond the volume handle.
type blockVolAdapter struct {
	vol *blockvol.BlockVol
}

// ReadAt delegates to ReadLBA.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return a.vol.ReadLBA(lba, length)
}

// WriteAt delegates to WriteLBA.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
	return a.vol.WriteLBA(lba, data)
}

// Trim delegates to the volume's Trim.
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
	return a.vol.Trim(lba, length)
}

// SyncCache flushes buffered writes via the volume's SyncCache.
func (a *blockVolAdapter) SyncCache() error {
	return a.vol.SyncCache()
}

// BlockSize reports the volume's block size from Info().
func (a *blockVolAdapter) BlockSize() uint32 {
	return a.vol.Info().BlockSize
}

// VolumeSize reports the total volume size from Info().
func (a *blockVolAdapter) VolumeSize() uint64 {
	return a.vol.Info().VolumeSize
}

// IsHealthy reports the volume's Healthy flag from Info().
func (a *blockVolAdapter) IsHealthy() bool {
	return a.vol.Info().Healthy
}

// Fixed IQNs used by every integration subtest.
const (
	intTargetName    = "iqn.2024.com.seaweedfs:integration"
	intInitiatorName = "iqn.2024.com.test:client"
)
// createTestVol creates a throwaway 4MB BlockVol (1024 × 4KiB blocks)
// with a 1MB WAL in a temp directory; the volume is closed automatically
// via t.Cleanup.
func createTestVol(t *testing.T) *blockvol.BlockVol {
	t.Helper()
	path := filepath.Join(t.TempDir(), "test.blk")
	vol, err := blockvol.CreateBlockVol(path, blockvol.CreateOptions{
		VolumeSize: 1024 * 4096, // 1024 blocks = 4MB
		BlockSize:  4096,
		WALSize:    1024 * 1024, // 1MB WAL
	})
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { vol.Close() })
	return vol
}
// setupIntegrationTarget starts an in-process iSCSI target backed by a
// temp BlockVol, dials it over loopback, performs a Normal-session login
// in a single PDU, and returns the connected initiator socket and the
// server. All resources are released via t.Cleanup.
func setupIntegrationTarget(t *testing.T) (net.Conn, *iscsi.TargetServer) {
	t.Helper()
	vol := createTestVol(t)
	adapter := &blockVolAdapter{vol: vol}
	config := iscsi.DefaultTargetConfig()
	config.TargetName = intTargetName
	logger := log.New(io.Discard, "", 0)
	// The addr passed to NewTargetServer is unused here; we listen on an
	// ephemeral port ourselves and hand the listener to Serve.
	ts := iscsi.NewTargetServer("127.0.0.1:0", config, logger)
	ts.AddVolume(intTargetName, adapter)
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	go ts.Serve(ln)
	t.Cleanup(func() { ts.Close() })
	conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { conn.Close() })
	// Login: transit straight from SecurityNegotiation to FullFeature.
	params := iscsi.NewParams()
	params.Set("InitiatorName", intInitiatorName)
	params.Set("TargetName", intTargetName)
	params.Set("SessionType", "Normal")
	loginReq := &iscsi.PDU{}
	loginReq.SetOpcode(iscsi.OpLoginReq)
	loginReq.SetLoginStages(iscsi.StageSecurityNeg, iscsi.StageFullFeature)
	loginReq.SetLoginTransit(true)
	loginReq.SetISID([6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01})
	loginReq.SetCmdSN(1)
	loginReq.DataSegment = params.Encode()
	if err := iscsi.WritePDU(conn, loginReq); err != nil {
		t.Fatal(err)
	}
	resp, err := iscsi.ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.LoginStatusClass() != iscsi.LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
	return conn, ts
}
// TestIntegration runs end-to-end initiator↔target subtests, each with
// its own fresh target instance, selectable with -run 'TestIntegration/<name>'.
func TestIntegration(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"write_read_verify", testWriteReadVerify},
		{"multi_block_write_read", testMultiBlockWriteRead},
		{"inquiry_capacity", testInquiryCapacity},
		{"sync_cache", testIntSyncCache},
		{"test_unit_ready", testIntTestUnitReady},
		{"concurrent_readers_writers", testConcurrentReadersWriters},
		{"write_at_boundary", testWriteAtBoundary},
		{"unmap_integration", testUnmapIntegration},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
// sendSCSICmd builds a single SCSI Command PDU, writes it to conn, and
// returns the FIRST response PDU only. cmdSN doubles as the ITT for
// easy matching; dataOut, when non-nil, rides as immediate data in the
// command PDU. Callers expecting multi-PDU Data-In must keep reading
// themselves.
func sendSCSICmd(t *testing.T, conn net.Conn, cdb [16]byte, cmdSN uint32, read bool, write bool, dataOut []byte, expLen uint32) *iscsi.PDU {
	t.Helper()
	cmd := &iscsi.PDU{}
	cmd.SetOpcode(iscsi.OpSCSICmd)
	flags := uint8(iscsi.FlagF)
	if read {
		flags |= iscsi.FlagR
	}
	if write {
		flags |= iscsi.FlagW
	}
	cmd.SetOpSpecific1(flags)
	cmd.SetInitiatorTaskTag(cmdSN)
	cmd.SetExpectedDataTransferLength(expLen)
	cmd.SetCmdSN(cmdSN)
	cmd.SetCDB(cdb)
	if dataOut != nil {
		cmd.DataSegment = dataOut
	}
	if err := iscsi.WritePDU(conn, cmd); err != nil {
		t.Fatal(err)
	}
	resp, err := iscsi.ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	return resp
}
// testWriteReadVerify round-trips a single 4 KiB block: WRITE(10) at LBA 0,
// READ(10) of the same block, then a byte-for-byte payload comparison.
func testWriteReadVerify(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// Write pattern to LBA 0
	writeData := make([]byte, 4096)
	for i := range writeData {
		writeData[i] = byte(i % 251) // prime modulus for pattern
	}
	// WRITE(10) CDB: LBA in bytes 2-5, transfer length (blocks) in bytes 7-8.
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 0) // LBA 0
	binary.BigEndian.PutUint16(wCDB[7:9], 1) // 1 block
	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}
	// Read it back
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 0)
	binary.BigEndian.PutUint16(rCDB[7:9], 1)
	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, 4096)
	if resp2.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", iscsi.OpcodeName(resp2.Opcode()))
	}
	if !bytes.Equal(resp2.DataSegment, writeData) {
		t.Fatal("data integrity verification failed")
	}
}
// testMultiBlockWriteRead writes four consecutive blocks, each with a
// distinct fill byte, then reads them back, reassembling the Data-In
// sequence if the target splits it across several PDUs.
func testMultiBlockWriteRead(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// Write 4 blocks at LBA 10
	blockCount := uint16(4)
	dataLen := int(blockCount) * 4096
	writeData := make([]byte, dataLen)
	for i := range writeData {
		writeData[i] = byte((i / 4096) + 1) // each block has different fill
	}
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 10)
	binary.BigEndian.PutUint16(wCDB[7:9], blockCount)
	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, uint32(dataLen))
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}
	// Read back all 4 blocks
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 10)
	binary.BigEndian.PutUint16(rCDB[7:9], blockCount)
	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, uint32(dataLen))
	// May be split across multiple Data-In PDUs — reassemble
	var readData []byte
	readData = append(readData, resp2.DataSegment...)
	// If first PDU doesn't have the S (status) bit — which marks the final
	// Data-In PDU of the sequence — keep reading and appending.
	for resp2.OpSpecific1()&iscsi.FlagS == 0 {
		var err error
		resp2, err = iscsi.ReadPDU(conn)
		if err != nil {
			t.Fatal(err)
		}
		readData = append(readData, resp2.DataSegment...)
	}
	if !bytes.Equal(readData, writeData) {
		t.Fatal("multi-block data integrity failed")
	}
}
// testInquiryCapacity exercises the discovery-style commands — TEST UNIT
// READY, INQUIRY, and READ CAPACITY(10) — and checks the device type plus
// the 1024-block x 4096-byte geometry the integration target is built with.
func testInquiryCapacity(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// TEST UNIT READY
	var tuCDB [16]byte
	tuCDB[0] = iscsi.ScsiTestUnitReady
	resp := sendSCSICmd(t, conn, tuCDB, 2, false, false, nil, 0)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatal("test unit ready failed")
	}
	// INQUIRY with a 96-byte allocation length (CDB bytes 3-4).
	var iqCDB [16]byte
	iqCDB[0] = iscsi.ScsiInquiry
	binary.BigEndian.PutUint16(iqCDB[3:5], 96)
	resp = sendSCSICmd(t, conn, iqCDB, 3, true, false, nil, 96)
	if resp.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatal("expected Data-In for inquiry")
	}
	// Byte 0 of standard INQUIRY data is the peripheral device type;
	// 0x00 means direct-access block device.
	if resp.DataSegment[0] != 0x00 { // SBC device
		t.Fatal("not SBC device type")
	}
	// READ CAPACITY 10: response is last LBA then block size, 4 bytes each,
	// big-endian.
	var rcCDB [16]byte
	rcCDB[0] = iscsi.ScsiReadCapacity10
	resp = sendSCSICmd(t, conn, rcCDB, 4, true, false, nil, 8)
	if resp.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatal("expected Data-In for read capacity")
	}
	lastLBA := binary.BigEndian.Uint32(resp.DataSegment[0:4])
	blockSize := binary.BigEndian.Uint32(resp.DataSegment[4:8])
	if lastLBA != 1023 { // 1024 blocks, last = 1023
		t.Fatalf("last LBA: %d, expected 1023", lastLBA)
	}
	if blockSize != 4096 {
		t.Fatalf("block size: %d", blockSize)
	}
}
// testIntSyncCache issues SYNCHRONIZE CACHE(10) and expects GOOD status.
func testIntSyncCache(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	var cdb [16]byte
	cdb[0] = iscsi.ScsiSyncCache10
	if resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0); resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("sync cache failed: %d", resp.SCSIStatus())
	}
}
// testIntTestUnitReady issues TEST UNIT READY and expects GOOD status.
func testIntTestUnitReady(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	var cdb [16]byte
	cdb[0] = iscsi.ScsiTestUnitReady
	if resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0); resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatal("TUR failed")
	}
}
// testConcurrentReadersWriters fills LBAs 0-9, each with a distinct byte
// pattern, then reads every block back and verifies it.
//
// NOTE(review): despite the name, all commands are issued sequentially on a
// single connection — there is no actual concurrency here. Consider renaming
// or adding a truly parallel variant.
func testConcurrentReadersWriters(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// Sequential writes to different LBAs, then sequential reads to verify
	cmdSN := uint32(2) // CmdSN 1 was consumed by login
	for lba := uint32(0); lba < 10; lba++ {
		data := bytes.Repeat([]byte{byte(lba)}, 4096)
		var wCDB [16]byte
		wCDB[0] = iscsi.ScsiWrite10
		binary.BigEndian.PutUint32(wCDB[2:6], lba)
		binary.BigEndian.PutUint16(wCDB[7:9], 1)
		resp := sendSCSICmd(t, conn, wCDB, cmdSN, false, true, data, 4096)
		if resp.SCSIStatus() != iscsi.SCSIStatusGood {
			t.Fatalf("write LBA %d failed", lba)
		}
		cmdSN++
	}
	// Read back and verify
	for lba := uint32(0); lba < 10; lba++ {
		var rCDB [16]byte
		rCDB[0] = iscsi.ScsiRead10
		binary.BigEndian.PutUint32(rCDB[2:6], lba)
		binary.BigEndian.PutUint16(rCDB[7:9], 1)
		resp := sendSCSICmd(t, conn, rCDB, cmdSN, true, false, nil, 4096)
		if resp.Opcode() != iscsi.OpSCSIDataIn {
			t.Fatalf("LBA %d: expected Data-In", lba)
		}
		expected := bytes.Repeat([]byte{byte(lba)}, 4096)
		if !bytes.Equal(resp.DataSegment, expected) {
			t.Fatalf("LBA %d: data mismatch", lba)
		}
		cmdSN++
	}
}
// testWriteAtBoundary writes the last valid block (LBA 1023 on the
// 1024-block device) and then confirms a write one block past the end is
// rejected with a non-GOOD status.
func testWriteAtBoundary(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// Write at last valid LBA (1023)
	data := bytes.Repeat([]byte{0xFF}, 4096)
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 1023)
	binary.BigEndian.PutUint16(wCDB[7:9], 1)
	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, data, 4096)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatal("boundary write failed")
	}
	// Write past end — should fail
	var wCDB2 [16]byte
	wCDB2[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB2[2:6], 1024) // out of bounds
	binary.BigEndian.PutUint16(wCDB2[7:9], 1)
	resp2 := sendSCSICmd(t, conn, wCDB2, 3, false, true, data, 4096)
	if resp2.SCSIStatus() == iscsi.SCSIStatusGood {
		t.Fatal("OOB write should fail")
	}
}
// testUnmapIntegration writes a block, UNMAPs it, and verifies that reading
// the unmapped LBA returns all zeros.
func testUnmapIntegration(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)
	// Write data at LBA 5. (Write status is not checked here; a failed write
	// would surface as non-zero data in the final read anyway.)
	writeData := bytes.Repeat([]byte{0xCC}, 4096)
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 5)
	binary.BigEndian.PutUint16(wCDB[7:9], 1)
	sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096)
	// UNMAP parameter list: 2-byte data length (22 = bytes that follow the
	// field), 2-byte block-descriptor data length (16), then one descriptor
	// with an 8-byte starting LBA at offset 8 and a 4-byte block count at
	// offset 16.
	unmapData := make([]byte, 24)
	binary.BigEndian.PutUint16(unmapData[0:2], 22)
	binary.BigEndian.PutUint16(unmapData[2:4], 16)
	binary.BigEndian.PutUint64(unmapData[8:16], 5)
	binary.BigEndian.PutUint32(unmapData[16:20], 1)
	var uCDB [16]byte
	uCDB[0] = iscsi.ScsiUnmap
	resp := sendSCSICmd(t, conn, uCDB, 3, false, true, unmapData, uint32(len(unmapData)))
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("unmap failed: %d", resp.SCSIStatus())
	}
	// Read back — should be zeros
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 5)
	binary.BigEndian.PutUint16(rCDB[7:9], 1)
	resp2 := sendSCSICmd(t, conn, rCDB, 4, true, false, nil, 4096)
	if resp2.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatal("expected Data-In")
	}
	zeros := make([]byte, 4096)
	if !bytes.Equal(resp2.DataSegment, zeros) {
		t.Fatal("unmapped block should return zeros")
	}
}

396
weed/storage/blockvol/iscsi/login.go

@ -0,0 +1,396 @@
package iscsi
import (
"errors"
"strconv"
)
// Login status classes (RFC 7143, Section 11.13.5)
const (
	LoginStatusSuccess      uint8 = 0x00
	LoginStatusRedirect     uint8 = 0x01
	LoginStatusInitiatorErr uint8 = 0x02
	LoginStatusTargetErr    uint8 = 0x03
)

// Login status details. A detail code is only meaningful together with its
// status class, which is why several constants share a numeric value (e.g.
// LoginDetailSuccess, LoginDetailInitiatorError, and LoginDetailTargetError
// are all 0x00 under their respective classes).
const (
	LoginDetailSuccess           uint8 = 0x00
	LoginDetailTargetMoved       uint8 = 0x01 // redirect: permanently moved
	LoginDetailTargetMovedTemp   uint8 = 0x02 // redirect: temporarily moved
	LoginDetailInitiatorError    uint8 = 0x00 // initiator error (miscellaneous)
	LoginDetailAuthFailure       uint8 = 0x01 // authentication failure
	LoginDetailAuthorizationFail uint8 = 0x02 // authorization failure
	LoginDetailNotFound          uint8 = 0x03 // target not found
	LoginDetailTargetRemoved     uint8 = 0x04 // target removed
	LoginDetailUnsupported       uint8 = 0x05 // unsupported version
	LoginDetailTooManyConns      uint8 = 0x06 // too many connections
	LoginDetailMissingParam      uint8 = 0x07 // missing parameter
	LoginDetailNoSessionSlot     uint8 = 0x08 // no session slot
	LoginDetailNoTCPConn         uint8 = 0x09 // no TCP connection available
	LoginDetailNoSession         uint8 = 0x0a // no existing session
	LoginDetailTargetError       uint8 = 0x00 // target error (miscellaneous)
	LoginDetailServiceUnavail    uint8 = 0x01 // service unavailable
	LoginDetailOutOfResources    uint8 = 0x02 // out of resources
)

// Sentinel errors for login processing; compare with errors.Is.
var (
	ErrLoginInvalidStage   = errors.New("iscsi: invalid login stage transition")
	ErrLoginMissingParam   = errors.New("iscsi: missing required login parameter")
	ErrLoginInvalidISID    = errors.New("iscsi: invalid ISID")
	ErrLoginTargetNotFound = errors.New("iscsi: target not found")
	ErrLoginSessionExists  = errors.New("iscsi: session already exists for this ISID")
	ErrLoginInvalidRequest = errors.New("iscsi: invalid login request")
)
// LoginPhase tracks the current phase of login negotiation as seen by the
// target-side state machine (LoginNegotiator).
type LoginPhase int

const (
	LoginPhaseStart       LoginPhase = iota // before first PDU
	LoginPhaseSecurity                      // CSG=SecurityNeg
	LoginPhaseOperational                   // CSG=LoginOp
	LoginPhaseDone                          // transition to FFP complete
)
// TargetConfig holds the target-side negotiation defaults. The zero value is
// not directly usable; start from DefaultTargetConfig and override fields.
type TargetConfig struct {
	TargetName               string
	TargetAlias              string
	MaxRecvDataSegmentLength int
	MaxBurstLength           int
	FirstBurstLength         int
	MaxConnections           int
	MaxOutstandingR2T        int
	DefaultTime2Wait         int
	DefaultTime2Retain       int
	DataPDUInOrder           bool
	DataSequenceInOrder      bool
	InitialR2T               bool
	ImmediateData            bool
	ErrorRecoveryLevel       int
}

// DefaultTargetConfig returns sensible defaults for a target: 256 KiB data
// segment and burst limits, 64 KiB first burst, one connection, one
// outstanding R2T, in-order data delivery, and error recovery level 0.
func DefaultTargetConfig() TargetConfig {
	var cfg TargetConfig
	cfg.MaxRecvDataSegmentLength = 256 * 1024
	cfg.MaxBurstLength = 256 * 1024
	cfg.FirstBurstLength = 64 * 1024
	cfg.MaxConnections = 1
	cfg.MaxOutstandingR2T = 1
	cfg.DefaultTime2Wait = 2
	cfg.DefaultTime2Retain = 0
	cfg.DataPDUInOrder = true
	cfg.DataSequenceInOrder = true
	cfg.InitialR2T = true
	cfg.ImmediateData = true
	cfg.ErrorRecoveryLevel = 0
	return cfg
}
// LoginNegotiator handles the target side of login negotiation. It is a
// per-connection state machine: feed each Login Request PDU to
// HandleLoginPDU until Done() reports true, then read the outcome via
// Result(). It is not safe for concurrent use.
type LoginNegotiator struct {
	config   TargetConfig
	phase    LoginPhase
	isid     [6]byte // initiator session ID captured from the security phase
	tsih     uint16  // target session identifying handle, assigned on success
	targetOK bool    // target name validated
	// Negotiated values (updated during negotiation)
	NegMaxRecvDataSegLen int
	NegMaxBurstLength    int
	NegFirstBurstLength  int
	NegInitialR2T        bool
	NegImmediateData     bool
	// Initiator/target info captured during login
	InitiatorName string
	TargetName    string
	SessionType   string // "Normal" or "Discovery"
}

// NewLoginNegotiator creates a negotiator for a new login sequence.
// Negotiated values start at the target's own defaults and are narrowed as
// the initiator offers parameters.
func NewLoginNegotiator(config TargetConfig) *LoginNegotiator {
	return &LoginNegotiator{
		config:               config,
		phase:                LoginPhaseStart,
		NegMaxRecvDataSegLen: config.MaxRecvDataSegmentLength,
		NegMaxBurstLength:    config.MaxBurstLength,
		NegFirstBurstLength:  config.FirstBurstLength,
		NegInitialR2T:        config.InitialR2T,
		NegImmediateData:     config.ImmediateData,
	}
}
// HandleLoginPDU processes one login request PDU and returns the response PDU.
// It manages stage transitions and parameter negotiation.
//
// On any protocol violation a Login Response carrying a non-success status
// class (and a cleared transit bit) is returned; the negotiator never
// advances to LoginPhaseDone on a rejected request.
func (ln *LoginNegotiator) HandleLoginPDU(req *PDU, resolver TargetResolver) *PDU {
	// Echo the identifying fields back regardless of outcome.
	resp := &PDU{}
	resp.SetOpcode(OpLoginResp)
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetISID(req.ISID())
	resp.SetTSIH(req.TSIH())
	// Validate opcode
	if req.Opcode() != OpLoginReq {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}
	csg := req.LoginCSG()         // current stage claimed by the initiator
	nsg := req.LoginNSG()         // next stage requested
	transit := req.LoginTransit() // T bit: initiator wants to advance stages
	// Parse text parameters from data segment
	params, err := ParseParams(req.DataSegment)
	if err != nil {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}
	// Process based on current stage
	respParams := NewParams()
	switch csg {
	case StageSecurityNeg:
		// Security PDUs are only legal before operational negotiation begins.
		if ln.phase != LoginPhaseStart && ln.phase != LoginPhaseSecurity {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
			return resp
		}
		ln.phase = LoginPhaseSecurity
		// Capture initiator name (mandatory on the first PDU; later PDUs may
		// rely on the previously captured value).
		if name, ok := params.Get("InitiatorName"); ok {
			ln.InitiatorName = name
		} else if ln.InitiatorName == "" {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}
		// Session type; absent means a Normal (I/O) session.
		if st, ok := params.Get("SessionType"); ok {
			ln.SessionType = st
		}
		if ln.SessionType == "" {
			ln.SessionType = "Normal"
		}
		// Target name (required for Normal sessions); validated against the
		// resolver so unknown IQNs are rejected with "not found".
		if tn, ok := params.Get("TargetName"); ok {
			if ln.SessionType == "Normal" {
				if resolver == nil || !resolver.HasTarget(tn) {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailNotFound)
					return resp
				}
			}
			ln.TargetName = tn
			ln.targetOK = true
		} else if ln.SessionType == "Normal" && !ln.targetOK {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}
		// ISID
		ln.isid = req.ISID()
		// We don't implement CHAP — declare AuthMethod=None
		respParams.Set("AuthMethod", "None")
		if transit {
			if nsg == StageLoginOp {
				ln.phase = LoginPhaseOperational
			} else if nsg == StageFullFeature {
				ln.phase = LoginPhaseDone
			}
		}
	case StageLoginOp:
		if ln.phase != LoginPhaseOperational && ln.phase != LoginPhaseSecurity {
			// Allow direct jump to LoginOp if security was skipped
			if ln.phase == LoginPhaseStart {
				// Need InitiatorName at minimum
				// NOTE(review): on this direct-to-LoginOp path, TargetName and
				// SessionType are never captured or validated, so a Normal
				// session can reach full-feature phase with an empty
				// TargetName. Confirm whether the session layer rejects that.
				if name, ok := params.Get("InitiatorName"); ok {
					ln.InitiatorName = name
				} else if ln.InitiatorName == "" {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
					return resp
				}
			} else {
				setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
				return resp
			}
		}
		ln.phase = LoginPhaseOperational
		// Negotiate operational parameters
		ln.negotiateParams(params, respParams)
		if transit && nsg == StageFullFeature {
			ln.phase = LoginPhaseDone
		}
	default:
		// CSG=FullFeature (or any other value) is never a valid current stage.
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}
	// Build response: mirror the requested stages and transit bit.
	resp.SetLoginStages(csg, nsg)
	if transit {
		resp.SetLoginTransit(true)
	}
	resp.SetLoginStatus(LoginStatusSuccess, LoginDetailSuccess)
	// Assign TSIH on first successful login
	if ln.tsih == 0 {
		ln.tsih = 1 // simplified: single session
	}
	resp.SetTSIH(ln.tsih)
	// Encode response params
	if respParams.Len() > 0 {
		resp.DataSegment = respParams.Encode()
	}
	return resp
}
// Done returns true if login negotiation is complete and the connection may
// enter full-feature phase.
func (ln *LoginNegotiator) Done() bool {
	return ln.phase == LoginPhaseDone
}

// Phase returns the current login phase.
func (ln *LoginNegotiator) Phase() LoginPhase {
	return ln.phase
}
// negotiateParams processes operational parameter negotiation: for every key
// the initiator offered, it records the negotiated value on ln where tracked
// and writes the target's reply into resp. Keys the target does not
// understand are answered with "NotUnderstood" per RFC 7143.
func (ln *LoginNegotiator) negotiateParams(req *Params, resp *Params) {
	req.Each(func(key, value string) {
		switch key {
		case "MaxRecvDataSegmentLength":
			// Declarative key: the reply carries the target's own receive
			// limit, while the stored value is the negotiated min.
			// NOTE(review): unlike the other numeric keys below, the reply
			// here echoes the config value rather than v — confirm this
			// asymmetry is intended.
			if v, err := NegotiateNumber(value, ln.config.MaxRecvDataSegmentLength, 512, 16777215); err == nil {
				ln.NegMaxRecvDataSegLen = v
				resp.Set(key, strconv.Itoa(ln.config.MaxRecvDataSegmentLength))
			}
		case "MaxBurstLength":
			if v, err := NegotiateNumber(value, ln.config.MaxBurstLength, 512, 16777215); err == nil {
				ln.NegMaxBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "FirstBurstLength":
			if v, err := NegotiateNumber(value, ln.config.FirstBurstLength, 512, 16777215); err == nil {
				ln.NegFirstBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "InitialR2T":
			if v, err := NegotiateBool(value, ln.config.InitialR2T); err == nil {
				// InitialR2T uses OR semantics: result is Yes if either side says Yes
				if value == "Yes" || ln.config.InitialR2T {
					ln.NegInitialR2T = true
				} else {
					ln.NegInitialR2T = v
				}
				resp.Set(key, BoolStr(ln.NegInitialR2T))
			}
		case "ImmediateData":
			if v, err := NegotiateBool(value, ln.config.ImmediateData); err == nil {
				ln.NegImmediateData = v
				resp.Set(key, BoolStr(v))
			}
		// The remaining keys are answered from config but not tracked on ln.
		case "MaxConnections":
			resp.Set(key, strconv.Itoa(ln.config.MaxConnections))
		case "DataPDUInOrder":
			resp.Set(key, BoolStr(ln.config.DataPDUInOrder))
		case "DataSequenceInOrder":
			resp.Set(key, BoolStr(ln.config.DataSequenceInOrder))
		case "DefaultTime2Wait":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Wait))
		case "DefaultTime2Retain":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Retain))
		case "MaxOutstandingR2T":
			resp.Set(key, strconv.Itoa(ln.config.MaxOutstandingR2T))
		case "ErrorRecoveryLevel":
			resp.Set(key, strconv.Itoa(ln.config.ErrorRecoveryLevel))
		case "HeaderDigest":
			// Digests are not implemented; always answer None regardless of offer.
			resp.Set(key, "None")
		case "DataDigest":
			resp.Set(key, "None")
		case "TargetName", "InitiatorName", "SessionType", "AuthMethod":
			// Already handled or declarative — skip
		case "TargetAlias":
			// Informational from initiator — skip
		default:
			// Unknown keys: respond with NotUnderstood
			resp.Set(key, "NotUnderstood")
		}
	})
	// Always declare our TargetAlias if configured
	if ln.config.TargetAlias != "" {
		resp.Set("TargetAlias", ln.config.TargetAlias)
	}
}
// TargetResolver allows the login state machine to check if a target exists.
// It is implemented by whichever registry owns the set of exported IQNs.
type TargetResolver interface {
	HasTarget(name string) bool
}

// setLoginReject marks resp as a failed login with the given status class
// and detail, clearing the transit bit so the initiator does not advance.
func setLoginReject(resp *PDU, class, detail uint8) {
	resp.SetLoginStatus(class, detail)
	resp.SetLoginTransit(false)
}
// LoginResult contains the outcome of a completed login negotiation as a
// plain value the session layer can keep after the negotiator is discarded.
type LoginResult struct {
	InitiatorName     string
	TargetName        string
	SessionType       string // "Normal" or "Discovery"
	ISID              [6]byte
	TSIH              uint16
	MaxRecvDataSegLen int
	MaxBurstLength    int
	FirstBurstLength  int
	InitialR2T        bool
	ImmediateData     bool
}

// Result returns the negotiation outcome. Only valid after Done() returns true.
func (ln *LoginNegotiator) Result() LoginResult {
	return LoginResult{
		InitiatorName:     ln.InitiatorName,
		TargetName:        ln.TargetName,
		SessionType:       ln.SessionType,
		ISID:              ln.isid,
		TSIH:              ln.tsih,
		MaxRecvDataSegLen: ln.NegMaxRecvDataSegLen,
		MaxBurstLength:    ln.NegMaxBurstLength,
		FirstBurstLength:  ln.NegFirstBurstLength,
		InitialR2T:        ln.NegInitialR2T,
		ImmediateData:     ln.NegImmediateData,
	}
}
// BuildRedirectResponse creates a login response that redirects the initiator
// to a different target address. permanent selects the "moved permanently"
// detail code instead of "moved temporarily"; addr is carried as the
// TargetAddress text parameter.
func BuildRedirectResponse(req *PDU, addr string, permanent bool) *PDU {
	detail := LoginDetailTargetMovedTemp
	if permanent {
		detail = LoginDetailTargetMoved
	}
	kv := NewParams()
	kv.Set("TargetAddress", addr)
	resp := &PDU{}
	resp.SetOpcode(OpLoginResp)
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetISID(req.ISID())
	resp.SetLoginStatus(LoginStatusRedirect, detail)
	resp.DataSegment = kv.Encode()
	return resp
}

444
weed/storage/blockvol/iscsi/login_test.go

@ -0,0 +1,444 @@
package iscsi
import (
"testing"
)
// mockResolver implements TargetResolver for tests: a target is present iff
// its name was registered via newResolver.
type mockResolver struct {
	targets map[string]bool
}

// HasTarget reports whether name was registered with this resolver.
func (m *mockResolver) HasTarget(name string) bool {
	return m.targets[name]
}

// newResolver builds a mockResolver that knows exactly the given names.
func newResolver(names ...string) *mockResolver {
	targets := make(map[string]bool, len(names))
	for _, name := range names {
		targets[name] = true
	}
	return &mockResolver{targets: targets}
}
// Well-formed IQNs shared by the login tests below.
const testTargetName = "iqn.2024.com.seaweedfs:vol1"
const testInitiatorName = "iqn.2024.com.test:client1"
// TestLogin drives every login-negotiation scenario as a named subtest.
func TestLogin(t *testing.T) {
	type subtest struct {
		name string
		fn   func(*testing.T)
	}
	subtests := []subtest{
		{"single_pdu_security_to_ffp", testSinglePDUSecurityToFFP},
		{"two_phase_login", testTwoPhaseLogin},
		{"security_to_loginop_to_ffp", testSecurityToLoginOpToFFP},
		{"missing_initiator_name", testMissingInitiatorName},
		{"target_not_found", testTargetNotFound},
		{"discovery_session_no_target", testDiscoverySessionNoTarget},
		{"wrong_opcode", testWrongOpcode},
		{"operational_negotiation", testOperationalNegotiation},
		{"redirect_permanent", testRedirectPermanent},
		{"redirect_temporary", testRedirectTemporary},
		{"login_result", testLoginResult},
		{"unknown_key_not_understood", testUnknownKeyNotUnderstood},
		{"header_data_digest_none", testHeaderDataDigestNone},
		{"duplicate_login_pdu", testDuplicateLoginPDU},
		{"invalid_csg", testInvalidCSG},
	}
	for _, st := range subtests {
		t.Run(st.name, st.fn)
	}
}
// makeLoginReq builds a Login Request PDU with the given current/next stages,
// transit bit, and optional text parameters. A fixed ISID is used so repeated
// requests within one test belong to the same would-be session.
func makeLoginReq(csg, nsg uint8, transit bool, params *Params) *PDU {
	p := &PDU{}
	p.SetOpcode(OpLoginReq)
	p.SetLoginStages(csg, nsg)
	if transit {
		p.SetLoginTransit(true)
	}
	isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}
	p.SetISID(isid)
	if params != nil {
		p.DataSegment = params.Encode()
	}
	return p
}
// testSinglePDUSecurityToFFP covers the fast path: a single Login Request
// with transit set jumps from SecurityNeg straight to full-feature phase;
// the target must succeed, assign a TSIH, and declare AuthMethod=None.
func testSinglePDUSecurityToFFP(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, resolver)
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status class: %d", resp.LoginStatusClass())
	}
	if !resp.LoginTransit() {
		t.Fatal("transit should be set")
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
	if resp.TSIH() == 0 {
		t.Fatal("TSIH should be assigned")
	}
	// Verify AuthMethod=None in response
	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, ok := rp.Get("AuthMethod"); !ok || v != "None" {
		t.Fatalf("AuthMethod: %q, %v", v, ok)
	}
}
// testTwoPhaseLogin walks the common two-PDU sequence: SecurityNeg ->
// LoginOp, then LoginOp -> FullFeature, checking Done() flips only at the end.
func testTwoPhaseLogin(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	// Phase 1: Security -> LoginOp
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	resp := ln.HandleLoginPDU(req, resolver)
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("phase 1 status: %d", resp.LoginStatusClass())
	}
	if ln.Done() {
		t.Fatal("should not be done yet")
	}
	// Phase 2: LoginOp -> FFP
	params2 := NewParams()
	params2.Set("MaxRecvDataSegmentLength", "65536")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp2 := ln.HandleLoginPDU(req2, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("phase 2 status: %d", resp2.LoginStatusClass())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
}
// testSecurityToLoginOpToFFP uses three PDUs: a security PDU without transit
// (stay in SecurityNeg), then Security -> LoginOp, then LoginOp -> FFP.
func testSecurityToLoginOpToFFP(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	// Security (no transit)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params)
	resp := ln.HandleLoginPDU(req, resolver)
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp.LoginStatusClass())
	}
	// Security -> LoginOp (transit); names were captured by the first PDU.
	req2 := makeLoginReq(StageSecurityNeg, StageLoginOp, true, NewParams())
	resp2 := ln.HandleLoginPDU(req2, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp2.LoginStatusClass())
	}
	// LoginOp -> FFP (transit)
	req3 := makeLoginReq(StageLoginOp, StageFullFeature, true, NewParams())
	resp3 := ln.HandleLoginPDU(req3, resolver)
	if resp3.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp3.LoginStatusClass())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
}
// testMissingInitiatorName verifies the mandatory InitiatorName key is
// enforced on the opening security-negotiation PDU.
func testMissingInitiatorName(t *testing.T) {
	neg := NewLoginNegotiator(DefaultTargetConfig())
	// Offer only a TargetName; InitiatorName is deliberately absent.
	p := NewParams()
	p.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, p)
	resp := neg.HandleLoginPDU(req, newResolver(testTargetName))
	if class := resp.LoginStatusClass(); class != LoginStatusInitiatorErr {
		t.Fatalf("expected initiator error, got %d", class)
	}
	if detail := resp.LoginStatusDetail(); detail != LoginDetailMissingParam {
		t.Fatalf("expected missing param, got %d", detail)
	}
}
// testTargetNotFound verifies that naming an IQN the resolver does not know
// is rejected with the "target not found" detail.
func testTargetNotFound(t *testing.T) {
	neg := NewLoginNegotiator(DefaultTargetConfig())
	p := NewParams()
	p.Set("InitiatorName", testInitiatorName)
	p.Set("TargetName", "iqn.2024.com.seaweedfs:nonexistent")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, p)
	resp := neg.HandleLoginPDU(req, newResolver()) // resolver knows no targets
	if class := resp.LoginStatusClass(); class != LoginStatusInitiatorErr {
		t.Fatalf("expected initiator error, got %d", class)
	}
	if detail := resp.LoginStatusDetail(); detail != LoginDetailNotFound {
		t.Fatalf("expected not found, got %d", detail)
	}
}
// testDiscoverySessionNoTarget verifies that a Discovery session logs in
// successfully with no TargetName and no resolver at all.
func testDiscoverySessionNoTarget(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("SessionType", "Discovery")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, nil) // no resolver needed
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
	if ln.SessionType != "Discovery" {
		t.Fatalf("session type: %q", ln.SessionType)
	}
}
// testWrongOpcode verifies that a non-Login PDU fed to the negotiator is
// rejected as an initiator error.
func testWrongOpcode(t *testing.T) {
	neg := NewLoginNegotiator(DefaultTargetConfig())
	bogus := &PDU{}
	bogus.SetOpcode(OpSCSICmd) // not a Login Request
	if resp := neg.HandleLoginPDU(bogus, nil); resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected error, got %d", resp.LoginStatusClass())
	}
}
// testOperationalNegotiation offers a full set of operational keys and
// checks the target's replies: digests forced to None, InitialR2T kept Yes,
// ImmediateData accepted as No, and the configured TargetAlias declared.
func testOperationalNegotiation(t *testing.T) {
	config := DefaultTargetConfig()
	config.TargetAlias = "test-target"
	ln := NewLoginNegotiator(config)
	resolver := newResolver(testTargetName)
	// First: security phase
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)
	// Second: operational negotiation
	params2 := NewParams()
	params2.Set("MaxRecvDataSegmentLength", "65536")
	params2.Set("MaxBurstLength", "131072")
	params2.Set("FirstBurstLength", "32768")
	params2.Set("InitialR2T", "Yes")
	params2.Set("ImmediateData", "No")
	params2.Set("HeaderDigest", "CRC32C,None")
	params2.Set("DataDigest", "CRC32C,None")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp.LoginStatusClass())
	}
	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	// Verify negotiated values
	if v, ok := rp.Get("HeaderDigest"); !ok || v != "None" {
		t.Fatalf("HeaderDigest: %q", v)
	}
	if v, ok := rp.Get("DataDigest"); !ok || v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
	if v, ok := rp.Get("InitialR2T"); !ok || v != "Yes" {
		t.Fatalf("InitialR2T: %q", v)
	}
	if v, ok := rp.Get("ImmediateData"); !ok || v != "No" {
		t.Fatalf("ImmediateData: %q", v)
	}
	if v, ok := rp.Get("TargetAlias"); !ok || v != "test-target" {
		t.Fatalf("TargetAlias: %q", v)
	}
	// Check negotiated state
	if ln.NegInitialR2T != true {
		t.Fatal("NegInitialR2T should be true")
	}
	if ln.NegImmediateData != false {
		t.Fatal("NegImmediateData should be false")
	}
}
// testRedirectPermanent checks a permanent redirect response: status class
// Redirect, detail TargetMoved, and TargetAddress carried in the data segment.
func testRedirectPermanent(t *testing.T) {
	login := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams())
	redirect := BuildRedirectResponse(login, "10.0.0.1:3260,1", true)
	if c := redirect.LoginStatusClass(); c != LoginStatusRedirect {
		t.Fatalf("class: %d", c)
	}
	if d := redirect.LoginStatusDetail(); d != LoginDetailTargetMoved {
		t.Fatalf("detail: %d", d)
	}
	kv, err := ParseParams(redirect.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if addr, _ := kv.Get("TargetAddress"); addr != "10.0.0.1:3260,1" {
		t.Fatalf("TargetAddress: %q", addr)
	}
}
// testRedirectTemporary checks the detail code of a temporary redirect.
func testRedirectTemporary(t *testing.T) {
	login := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams())
	redirect := BuildRedirectResponse(login, "192.168.1.100:3260,2", false)
	if d := redirect.LoginStatusDetail(); d != LoginDetailTargetMovedTemp {
		t.Fatalf("detail: %d", d)
	}
}
// testLoginResult checks that Result() reflects a completed login: the
// captured initiator name, the defaulted "Normal" session type, and an
// assigned TSIH.
func testLoginResult(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	ln.HandleLoginPDU(req, resolver)
	result := ln.Result()
	if result.InitiatorName != testInitiatorName {
		t.Fatalf("initiator: %q", result.InitiatorName)
	}
	// SessionType was not offered, so the negotiator defaults it to Normal.
	if result.SessionType != "Normal" {
		t.Fatalf("session type: %q", result.SessionType)
	}
	if result.TSIH == 0 {
		t.Fatal("TSIH should be assigned")
	}
}
// testUnknownKeyNotUnderstood verifies that an unrecognized operational key
// is answered with "NotUnderstood" rather than ignored or rejected.
func testUnknownKeyNotUnderstood(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	// Security phase first
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)
	// Op phase with unknown key
	params2 := NewParams()
	params2.Set("X-CustomKey", "whatever")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)
	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, ok := rp.Get("X-CustomKey"); !ok || v != "NotUnderstood" {
		t.Fatalf("expected NotUnderstood, got %q, %v", v, ok)
	}
}
// testHeaderDataDigestNone verifies the target always answers digest offers
// with None — even when the initiator offers only CRC32C — since digests are
// not implemented.
func testHeaderDataDigestNone(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)
	params2 := NewParams()
	params2.Set("HeaderDigest", "CRC32C")
	params2.Set("DataDigest", "CRC32C")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)
	rp, _ := ParseParams(resp.DataSegment)
	if v, _ := rp.Get("HeaderDigest"); v != "None" {
		t.Fatalf("HeaderDigest: %q (we always negotiate None)", v)
	}
	if v, _ := rp.Get("DataDigest"); v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
}
// testDuplicateLoginPDU re-sends the same non-transit security PDU twice;
// staying in the same stage is legal, so both must succeed.
func testDuplicateLoginPDU(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	// Send same security PDU twice (no transit)
	req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params)
	resp1 := ln.HandleLoginPDU(req, resolver)
	if resp1.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("first: %d", resp1.LoginStatusClass())
	}
	// Same stage again should still work
	resp2 := ln.HandleLoginPDU(req, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("second: %d", resp2.LoginStatusClass())
	}
}
// testInvalidCSG verifies that a Login Request claiming CSG=FullFeature —
// never a valid *current* stage — is rejected as an initiator error.
func testInvalidCSG(t *testing.T) {
	neg := NewLoginNegotiator(DefaultTargetConfig())
	// CSG=FullFeature (3) is invalid as a current stage
	req := &PDU{}
	req.SetOpcode(OpLoginReq)
	req.SetLoginStages(StageFullFeature, StageFullFeature)
	req.SetLoginTransit(true)
	req.SetISID([6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01})
	if resp := neg.HandleLoginPDU(req, nil); resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected error for invalid CSG, got %d", resp.LoginStatusClass())
	}
}

183
weed/storage/blockvol/iscsi/params.go

@ -0,0 +1,183 @@
package iscsi
import (
"errors"
"fmt"
"strconv"
"strings"
)
// iSCSI key-value text parameters (RFC 7143, Section 6).
// Parameters are encoded as "Key=Value\0" in the data segment.
var (
ErrMalformedParam = errors.New("iscsi: malformed parameter (missing '=')")
ErrEmptyKey = errors.New("iscsi: empty parameter key")
ErrDuplicateKey = errors.New("iscsi: duplicate parameter key")
)
// Params is an ordered list of iSCSI key-value parameters.
// Order matters for negotiation, so we use a slice rather than a map.
type Params struct {
items []paramItem
}
type paramItem struct {
key string
value string
}
// ParseParams decodes "Key=Value\0Key=Value\0..." from raw bytes.
// Empty trailing segments (from a trailing \0) are ignored.
// Returns ErrDuplicateKey if the same key appears more than once.
func ParseParams(data []byte) (*Params, error) {
p := &Params{}
if len(data) == 0 {
return p, nil
}
seen := make(map[string]bool)
s := string(data)
// Split by null separator
parts := strings.Split(s, "\x00")
for _, part := range parts {
if part == "" {
continue // trailing null or empty
}
idx := strings.IndexByte(part, '=')
if idx < 0 {
return nil, fmt.Errorf("%w: %q", ErrMalformedParam, part)
}
key := part[:idx]
if key == "" {
return nil, ErrEmptyKey
}
value := part[idx+1:]
if seen[key] {
return nil, fmt.Errorf("%w: %q", ErrDuplicateKey, key)
}
seen[key] = true
p.items = append(p.items, paramItem{key: key, value: value})
}
return p, nil
}
// Encode serializes parameters to "Key=Value\0" format.
// Returns nil (not an empty slice) when there are no parameters, so an
// empty list produces no data segment at all.
func (p *Params) Encode() []byte {
	if len(p.items) == 0 {
		return nil
	}
	// Pre-size the output: key + '=' + value + NUL per item.
	total := 0
	for _, it := range p.items {
		total += len(it.key) + len(it.value) + 2
	}
	out := make([]byte, 0, total)
	for _, it := range p.items {
		out = append(out, it.key...)
		out = append(out, '=')
		out = append(out, it.value...)
		out = append(out, 0)
	}
	return out
}
// Get returns the value for a key, or ("", false) if not present.
// Lookup is a linear scan; parameter lists are small (tens of entries).
func (p *Params) Get(key string) (string, bool) {
	for i := range p.items {
		if p.items[i].key == key {
			return p.items[i].value, true
		}
	}
	return "", false
}
// Set adds or replaces a parameter. If the key already exists, its value
// is updated in place (preserving its position in the negotiation order).
// Otherwise, the key is appended at the end.
func (p *Params) Set(key, value string) {
	for i := range p.items {
		if p.items[i].key == key {
			p.items[i].value = value
			return
		}
	}
	p.items = append(p.items, paramItem{key: key, value: value})
}
// Del removes a key if present; a missing key is a no-op.
func (p *Params) Del(key string) {
	for i := range p.items {
		if p.items[i].key != key {
			continue
		}
		// Shift the tail left over the deleted slot and truncate.
		copy(p.items[i:], p.items[i+1:])
		p.items = p.items[:len(p.items)-1]
		return
	}
}
// Keys returns all keys in insertion order.
func (p *Params) Keys() []string {
	out := make([]string, 0, len(p.items))
	for _, it := range p.items {
		out = append(out, it.key)
	}
	return out
}
// Len returns the number of parameters.
func (p *Params) Len() int { return len(p.items) }

// Each iterates over all key-value pairs in insertion order, invoking fn
// once per pair. fn must not mutate p during iteration.
func (p *Params) Each(fn func(key, value string)) {
	for _, item := range p.items {
		fn(item.key, item.value)
	}
}
// --- Negotiation helpers ---

// NegotiateNumber applies the iSCSI numeric negotiation rule
// (RFC 7143, Section 6.1): the offered value is first clamped to the
// key's legal range [lo, hi], then the result is the smaller of the
// (clamped) offer and our own value. Keys such as
// MaxRecvDataSegmentLength and MaxBurstLength all resolve to
// min(offer, ours), so min() is the single rule implemented here.
//
// offered is the initiator's value as decimal text; ours is the
// target's own value; lo and hi bound the key's legal range.
// Returns an error if offered does not parse as an integer.
func NegotiateNumber(offered string, ours int, lo, hi int) (int, error) {
	v, err := strconv.Atoi(offered)
	if err != nil {
		return 0, fmt.Errorf("iscsi: invalid numeric value %q: %w", offered, err)
	}
	// Clamp the offer to the key's legal range before comparing.
	if v < lo {
		v = lo
	}
	if v > hi {
		v = hi
	}
	// Standard negotiation: result is min(offer, ours).
	if ours < v {
		return ours, nil
	}
	return v, nil
}
// NegotiateBool applies boolean negotiation (AND semantics per RFC 7143).
// Result is true only if both sides agree to true. The offered value must
// be exactly "Yes" or "No" (iSCSI boolean literals are case-sensitive).
func NegotiateBool(offered string, ours bool) (bool, error) {
	if offered == "No" {
		// Either side saying No forces the result to No.
		return false, nil
	}
	if offered != "Yes" {
		return false, fmt.Errorf("iscsi: invalid boolean value %q", offered)
	}
	// Initiator said Yes; the outcome is whatever we want.
	return ours, nil
}
// BoolStr returns the iSCSI text literal "Yes" or "No" for a boolean.
func BoolStr(v bool) string {
	if !v {
		return "No"
	}
	return "Yes"
}
// NewParams creates a new empty Params, ready for Set calls.
func NewParams() *Params {
	return new(Params)
}

357
weed/storage/blockvol/iscsi/params_test.go

@ -0,0 +1,357 @@
package iscsi
import (
"bytes"
"strings"
"testing"
)
// TestParams is the umbrella test for the Params container and the
// negotiation helpers; each table entry runs as an independent subtest.
func TestParams(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"parse_single", testParseSingle},
		{"parse_multiple", testParseMultiple},
		{"parse_empty", testParseEmpty},
		{"parse_trailing_null", testParseTrailingNull},
		{"parse_malformed_no_equals", testParseMalformedNoEquals},
		{"parse_empty_key", testParseEmptyKey},
		{"parse_empty_value", testParseEmptyValue},
		{"parse_duplicate_key", testParseDuplicateKey},
		{"roundtrip", testParamsRoundtrip},
		{"encode_empty", testEncodeEmpty},
		{"set_new_key", testSetNewKey},
		{"set_existing_key", testSetExistingKey},
		{"del_key", testDelKey},
		{"del_nonexistent", testDelNonexistent},
		{"keys_order", testKeysOrder},
		{"each", testEach},
		{"negotiate_number_min", testNegotiateNumberMin},
		{"negotiate_number_clamp", testNegotiateNumberClamp},
		{"negotiate_number_invalid", testNegotiateNumberInvalid},
		{"negotiate_bool_and", testNegotiateBoolAnd},
		{"negotiate_bool_invalid", testNegotiateBoolInvalid},
		{"bool_str", testBoolStr},
		{"value_with_equals", testValueWithEquals},
		{"parse_binary_data_value", testParseBinaryDataValue},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}

// testParseSingle decodes a single well-formed parameter.
func testParseSingle(t *testing.T) {
	data := []byte("TargetName=iqn.2024.com.seaweedfs:vol1\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 1 {
		t.Fatalf("expected 1 param, got %d", p.Len())
	}
	v, ok := p.Get("TargetName")
	if !ok || v != "iqn.2024.com.seaweedfs:vol1" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseMultiple decodes several parameters and checks order is kept.
func testParseMultiple(t *testing.T) {
	data := []byte("InitiatorName=iqn.2024.com.test\x00TargetName=iqn.2024.com.seaweedfs:vol1\x00SessionType=Normal\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 3 {
		t.Fatalf("expected 3 params, got %d", p.Len())
	}
	keys := p.Keys()
	if keys[0] != "InitiatorName" || keys[1] != "TargetName" || keys[2] != "SessionType" {
		t.Fatalf("wrong order: %v", keys)
	}
}

// testParseEmpty checks that nil and zero-length input both decode to an
// empty list without error.
func testParseEmpty(t *testing.T) {
	p, err := ParseParams(nil)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 0 {
		t.Fatalf("expected 0, got %d", p.Len())
	}
	p2, err := ParseParams([]byte{})
	if err != nil {
		t.Fatal(err)
	}
	if p2.Len() != 0 {
		t.Fatalf("expected 0, got %d", p2.Len())
	}
}

// testParseTrailingNull checks tolerance for extra NUL separators.
func testParseTrailingNull(t *testing.T) {
	// Multiple trailing nulls should not cause errors
	data := []byte("Key=Value\x00\x00\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
}

// testParseMalformedNoEquals: a segment without '=' must be rejected.
func testParseMalformedNoEquals(t *testing.T) {
	data := []byte("KeyWithoutValue\x00")
	_, err := ParseParams(data)
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "malformed") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testParseEmptyKey: "=Value" has no key and must be rejected.
func testParseEmptyKey(t *testing.T) {
	data := []byte("=Value\x00")
	_, err := ParseParams(data)
	if err != ErrEmptyKey {
		t.Fatalf("expected ErrEmptyKey, got %v", err)
	}
}

// testParseEmptyValue: "Key=" is legal and yields an empty string value.
func testParseEmptyValue(t *testing.T) {
	// Empty value is valid in iSCSI (e.g., reject with empty value)
	data := []byte("Key=\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseDuplicateKey: the same key twice must be rejected.
func testParseDuplicateKey(t *testing.T) {
	data := []byte("Key=Value1\x00Key=Value2\x00")
	_, err := ParseParams(data)
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "duplicate") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testParamsRoundtrip: Encode(Parse(x)) reproduces x byte-for-byte.
func testParamsRoundtrip(t *testing.T) {
	original := []byte("InitiatorName=iqn.2024.com.test\x00MaxRecvDataSegmentLength=65536\x00")
	p, err := ParseParams(original)
	if err != nil {
		t.Fatal(err)
	}
	encoded := p.Encode()
	if !bytes.Equal(encoded, original) {
		t.Fatalf("roundtrip mismatch:\n got: %q\n want: %q", encoded, original)
	}
}

// testEncodeEmpty: an empty list encodes to nil, not to an empty slice.
func testEncodeEmpty(t *testing.T) {
	p := NewParams()
	if encoded := p.Encode(); encoded != nil {
		t.Fatalf("expected nil, got %q", encoded)
	}
}

// testSetNewKey: Set appends keys that are not yet present.
func testSetNewKey(t *testing.T) {
	p := NewParams()
	p.Set("Key1", "Val1")
	p.Set("Key2", "Val2")
	if p.Len() != 2 {
		t.Fatalf("expected 2, got %d", p.Len())
	}
	v, _ := p.Get("Key2")
	if v != "Val2" {
		t.Fatalf("got %q", v)
	}
}

// testSetExistingKey: Set replaces in place without growing the list.
func testSetExistingKey(t *testing.T) {
	p := NewParams()
	p.Set("Key", "Old")
	p.Set("Key", "New")
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
	v, _ := p.Get("Key")
	if v != "New" {
		t.Fatalf("got %q", v)
	}
}

// testDelKey: Del removes the key and preserves the order of the rest.
func testDelKey(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Set("B", "2")
	p.Set("C", "3")
	p.Del("B")
	if p.Len() != 2 {
		t.Fatalf("expected 2, got %d", p.Len())
	}
	if _, ok := p.Get("B"); ok {
		t.Fatal("B should be deleted")
	}
	keys := p.Keys()
	if keys[0] != "A" || keys[1] != "C" {
		t.Fatalf("wrong keys: %v", keys)
	}
}

// testDelNonexistent: deleting a missing key is a safe no-op.
func testDelNonexistent(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Del("Z") // should not panic
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
}

// testKeysOrder: Keys reflects insertion order, not sorted order.
func testKeysOrder(t *testing.T) {
	data := []byte("Z=1\x00A=2\x00M=3\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	keys := p.Keys()
	if keys[0] != "Z" || keys[1] != "A" || keys[2] != "M" {
		t.Fatalf("order not preserved: %v", keys)
	}
}

// testEach: Each visits every pair once, in insertion order.
func testEach(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Set("B", "2")
	var collected []string
	p.Each(func(k, v string) {
		collected = append(collected, k+"="+v)
	})
	if len(collected) != 2 || collected[0] != "A=1" || collected[1] != "B=2" {
		t.Fatalf("Each: %v", collected)
	}
}

// testNegotiateNumberMin: the result is the smaller of offer and ours,
// whichever side holds the smaller value.
func testNegotiateNumberMin(t *testing.T) {
	// Both sides offer a value, result is min
	result, err := NegotiateNumber("65536", 262144, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
	// Our value is smaller
	result, err = NegotiateNumber("262144", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
}

// testNegotiateNumberClamp: offers outside [min,max] are clamped to the
// legal range before the min() rule is applied.
func testNegotiateNumberClamp(t *testing.T) {
	// Offered value below minimum
	result, err := NegotiateNumber("100", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 512 {
		t.Fatalf("expected 512, got %d", result)
	}
	// Offered value above maximum
	result, err = NegotiateNumber("99999999", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
}

// testNegotiateNumberInvalid: non-numeric text is an error.
func testNegotiateNumberInvalid(t *testing.T) {
	_, err := NegotiateNumber("notanumber", 65536, 512, 16777215)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testNegotiateBoolAnd: exhaustive truth table for AND semantics.
func testNegotiateBoolAnd(t *testing.T) {
	// Both yes
	r, err := NegotiateBool("Yes", true)
	if err != nil || !r {
		t.Fatalf("Yes+true = %v, %v", r, err)
	}
	// Initiator yes, target no
	r, err = NegotiateBool("Yes", false)
	if err != nil || r {
		t.Fatalf("Yes+false = %v, %v", r, err)
	}
	// Initiator no, target yes
	r, err = NegotiateBool("No", true)
	if err != nil || r {
		t.Fatalf("No+true = %v, %v", r, err)
	}
	// Both no
	r, err = NegotiateBool("No", false)
	if err != nil || r {
		t.Fatalf("No+false = %v, %v", r, err)
	}
}

// testNegotiateBoolInvalid: anything but "Yes"/"No" is an error.
func testNegotiateBoolInvalid(t *testing.T) {
	_, err := NegotiateBool("Maybe", true)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testBoolStr: BoolStr maps booleans to iSCSI text literals.
func testBoolStr(t *testing.T) {
	if BoolStr(true) != "Yes" {
		t.Fatal("true should be Yes")
	}
	if BoolStr(false) != "No" {
		t.Fatal("false should be No")
	}
}

// testValueWithEquals: only the first '=' splits key from value.
func testValueWithEquals(t *testing.T) {
	// Value containing '=' is valid (only first '=' splits key from value)
	data := []byte("TargetAddress=10.0.0.1:3260,1\x00Key=a=b=c\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "a=b=c" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseBinaryDataValue: values may carry any non-NUL bytes.
func testParseBinaryDataValue(t *testing.T) {
	// Values can contain arbitrary bytes except null
	data := []byte("Key=\x01\x02\x03\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "\x01\x02\x03" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

447
weed/storage/blockvol/iscsi/pdu.go

@ -0,0 +1,447 @@
package iscsi
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// BHS (Basic Header Segment) is 48 bytes, big-endian.
// RFC 7143, Section 12.1
const BHSLength = 48

// Opcode constants — initiator opcodes (RFC 7143, Section 12.1.1).
// These occupy the lower 6 bits of BHS byte 0.
const (
	OpNOPOut       uint8 = 0x00
	OpSCSICmd      uint8 = 0x01
	OpSCSITaskMgmt uint8 = 0x02
	OpLoginReq     uint8 = 0x03
	OpTextReq      uint8 = 0x04
	OpSCSIDataOut  uint8 = 0x05
	OpLogoutReq    uint8 = 0x06
	OpSNACKReq     uint8 = 0x0c
)

// Target opcodes (RFC 7143, Section 12.1.1)
const (
	OpNOPIn        uint8 = 0x20
	OpSCSIResp     uint8 = 0x21
	OpSCSITaskResp uint8 = 0x22
	OpLoginResp    uint8 = 0x23
	OpTextResp     uint8 = 0x24
	OpSCSIDataIn   uint8 = 0x25
	OpLogoutResp   uint8 = 0x26
	OpR2T          uint8 = 0x31
	OpAsyncMsg     uint8 = 0x32
	OpReject       uint8 = 0x3f
)

// BHS flag masks. Several flags share a bit position because byte 1 is
// opcode-specific: e.g. 0x40 is R for SCSI commands, C for login PDUs,
// and A for Data-Out — the PDU's opcode determines which reading applies.
const (
	opcMask = 0x3f // lower 6 bits of byte 0
	FlagI   = 0x40 // Immediate delivery bit (byte 0)
	FlagF   = 0x80 // Final bit (byte 1)
	FlagR   = 0x40 // Read bit (byte 1, SCSI command)
	FlagW   = 0x20 // Write bit (byte 1, SCSI command)
	FlagC   = 0x40 // Continue bit (byte 1, login)
	FlagT   = 0x80 // Transit bit (byte 1, login)
	FlagS   = 0x01 // Status bit (byte 1, Data-In)
	FlagU   = 0x02 // Underflow bit (byte 1, SCSI Response)
	FlagO   = 0x04 // Overflow bit (byte 1, SCSI Response)
	FlagBiU = 0x08 // Bidi underflow (byte 1, SCSI Response)
	FlagBiO = 0x10 // Bidi overflow (byte 1, SCSI Response)
	FlagA   = 0x40 // Acknowledge bit (byte 1, Data-Out)
)

// Login stage constants (CSG/NSG, 2-bit fields in byte 1 of login PDU).
// Note: value 2 is reserved by the RFC; FullFeature is 3.
const (
	StageSecurityNeg uint8 = 0 // Security Negotiation
	StageLoginOp     uint8 = 1 // Login Operational Negotiation
	StageFullFeature uint8 = 3 // Full Feature Phase
)

// SCSI status codes
const (
	SCSIStatusGood         uint8 = 0x00
	SCSIStatusCheckCond    uint8 = 0x02
	SCSIStatusBusy         uint8 = 0x08
	SCSIStatusResvConflict uint8 = 0x18
)

// iSCSI response codes
const (
	ISCSIRespCompleted uint8 = 0x00
)

// MaxDataSegmentLength limits the maximum data segment we'll accept.
// The iSCSI data segment length is a 3-byte field (max 16MB-1).
// We cap at 8MB which is well above typical MaxRecvDataSegmentLength values.
const MaxDataSegmentLength = 8 * 1024 * 1024 // 8 MB

// Errors returned by the PDU codec.
var (
	ErrPDUTruncated     = errors.New("iscsi: PDU truncated")
	ErrPDUTooLarge      = errors.New("iscsi: data segment exceeds maximum length")
	ErrInvalidAHSLength = errors.New("iscsi: invalid AHS length (not multiple of 4)")
	ErrUnknownOpcode    = errors.New("iscsi: unknown opcode")
)

// PDU represents a full iSCSI Protocol Data Unit: the fixed 48-byte BHS
// plus optional AHS and data segments. Field accessors below interpret
// BHS bytes per-opcode; the struct itself is opcode-agnostic.
type PDU struct {
	BHS         [BHSLength]byte
	AHS         []byte // Additional Header Segment (multiple of 4 bytes)
	DataSegment []byte // Data segment (padded to 4-byte boundary on wire)
}
// --- BHS field accessors ---
//
// Several BHS byte ranges carry different fields depending on PDU
// direction: bytes 24-27 are CmdSN on initiator PDUs but StatSN on target
// PDUs, bytes 28-31 are ExpStatSN vs ExpCmdSN, and bytes 20-23 are
// ExpectedDataTransferLength vs TargetTransferTag. Both accessor names
// are provided; the caller picks the one matching the PDU's opcode.

// Opcode returns the opcode (lower 6 bits of byte 0).
func (p *PDU) Opcode() uint8 { return p.BHS[0] & opcMask }

// SetOpcode sets the opcode (lower 6 bits of byte 0), preserving the
// immediate bit in the upper bits.
func (p *PDU) SetOpcode(op uint8) {
	p.BHS[0] = (p.BHS[0] & ^uint8(opcMask)) | (op & opcMask)
}

// Immediate returns true if the immediate delivery bit is set.
func (p *PDU) Immediate() bool { return p.BHS[0]&FlagI != 0 }

// SetImmediate sets the immediate delivery bit.
func (p *PDU) SetImmediate(v bool) {
	if v {
		p.BHS[0] |= FlagI
	} else {
		p.BHS[0] &^= FlagI
	}
}

// OpSpecific1 returns byte 1 (opcode-specific flags).
func (p *PDU) OpSpecific1() uint8 { return p.BHS[1] }

// SetOpSpecific1 sets byte 1.
func (p *PDU) SetOpSpecific1(v uint8) { p.BHS[1] = v }

// TotalAHSLength returns the total AHS length in 4-byte words (byte 4).
func (p *PDU) TotalAHSLength() uint8 { return p.BHS[4] }

// DataSegmentLength returns the 3-byte data segment length (bytes 5-7).
func (p *PDU) DataSegmentLength() uint32 {
	return uint32(p.BHS[5])<<16 | uint32(p.BHS[6])<<8 | uint32(p.BHS[7])
}

// SetDataSegmentLength sets the 3-byte data segment length field.
func (p *PDU) SetDataSegmentLength(n uint32) {
	p.BHS[5] = byte(n >> 16)
	p.BHS[6] = byte(n >> 8)
	p.BHS[7] = byte(n)
}

// LUN returns the 8-byte LUN field (bytes 8-15).
func (p *PDU) LUN() uint64 { return binary.BigEndian.Uint64(p.BHS[8:16]) }

// SetLUN sets the LUN field.
func (p *PDU) SetLUN(lun uint64) { binary.BigEndian.PutUint64(p.BHS[8:16], lun) }

// InitiatorTaskTag returns bytes 16-19.
func (p *PDU) InitiatorTaskTag() uint32 { return binary.BigEndian.Uint32(p.BHS[16:20]) }

// SetInitiatorTaskTag sets bytes 16-19.
func (p *PDU) SetInitiatorTaskTag(tag uint32) { binary.BigEndian.PutUint32(p.BHS[16:20], tag) }

// Field32 reads a generic 4-byte big-endian field at the given BHS offset.
func (p *PDU) Field32(offset int) uint32 { return binary.BigEndian.Uint32(p.BHS[offset : offset+4]) }

// SetField32 writes a generic 4-byte big-endian field at the given BHS offset.
func (p *PDU) SetField32(offset int, v uint32) {
	binary.BigEndian.PutUint32(p.BHS[offset:offset+4], v)
}

// --- Common BHS offsets for named fields ---

// TSIH returns the Target Session Identifying Handle (bytes 14-15, login PDU).
func (p *PDU) TSIH() uint16 { return binary.BigEndian.Uint16(p.BHS[14:16]) }

// SetTSIH sets the TSIH field.
func (p *PDU) SetTSIH(v uint16) { binary.BigEndian.PutUint16(p.BHS[14:16], v) }

// CmdSN returns bytes 24-27 (initiator PDUs; same bytes as StatSN).
func (p *PDU) CmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetCmdSN sets bytes 24-27.
func (p *PDU) SetCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpStatSN returns bytes 28-31 (initiator PDUs; same bytes as ExpCmdSN).
func (p *PDU) ExpStatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpStatSN sets bytes 28-31.
func (p *PDU) SetExpStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// StatSN returns bytes 24-27 (target response PDUs).
func (p *PDU) StatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetStatSN sets bytes 24-27.
func (p *PDU) SetStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpCmdSN returns bytes 28-31 (target response PDUs).
func (p *PDU) ExpCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpCmdSN sets bytes 28-31.
func (p *PDU) SetExpCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// MaxCmdSN returns bytes 32-35 (target response PDUs).
func (p *PDU) MaxCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[32:36]) }

// SetMaxCmdSN sets bytes 32-35.
func (p *PDU) SetMaxCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[32:36], v) }

// DataSN returns bytes 36-39 (Data-In / Data-Out PDUs).
func (p *PDU) DataSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) }

// SetDataSN sets bytes 36-39.
func (p *PDU) SetDataSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) }

// BufferOffset returns bytes 40-43 (Data-In / Data-Out PDUs).
func (p *PDU) BufferOffset() uint32 { return binary.BigEndian.Uint32(p.BHS[40:44]) }

// SetBufferOffset sets bytes 40-43.
func (p *PDU) SetBufferOffset(v uint32) { binary.BigEndian.PutUint32(p.BHS[40:44], v) }

// R2TSN returns bytes 36-39 (R2T PDUs; same bytes as DataSN).
func (p *PDU) R2TSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) }

// SetR2TSN sets bytes 36-39.
func (p *PDU) SetR2TSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) }

// DesiredDataLength returns bytes 44-47 (R2T PDUs).
func (p *PDU) DesiredDataLength() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) }

// SetDesiredDataLength sets bytes 44-47.
func (p *PDU) SetDesiredDataLength(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) }

// ExpectedDataTransferLength returns bytes 20-23 (SCSI Command PDUs).
func (p *PDU) ExpectedDataTransferLength() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) }

// SetExpectedDataTransferLength sets bytes 20-23.
func (p *PDU) SetExpectedDataTransferLength(v uint32) {
	binary.BigEndian.PutUint32(p.BHS[20:24], v)
}

// TargetTransferTag returns bytes 20-23 (target PDUs: R2T, Data-In, etc.).
func (p *PDU) TargetTransferTag() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) }

// SetTargetTransferTag sets bytes 20-23.
func (p *PDU) SetTargetTransferTag(v uint32) { binary.BigEndian.PutUint32(p.BHS[20:24], v) }

// ISID returns the 6-byte Initiator Session ID (bytes 8-13, login PDU).
func (p *PDU) ISID() [6]byte {
	var id [6]byte
	copy(id[:], p.BHS[8:14])
	return id
}

// SetISID sets the 6-byte ISID field.
func (p *PDU) SetISID(id [6]byte) { copy(p.BHS[8:14], id[:]) }

// CDB returns the 16-byte SCSI CDB from the BHS (bytes 32-47, SCSI Command PDU).
func (p *PDU) CDB() [16]byte {
	var cdb [16]byte
	copy(cdb[:], p.BHS[32:48])
	return cdb
}

// SetCDB sets the 16-byte SCSI CDB in the BHS.
func (p *PDU) SetCDB(cdb [16]byte) { copy(p.BHS[32:48], cdb[:]) }

// ResidualCount returns bytes 44-47 (SCSI Response PDUs).
func (p *PDU) ResidualCount() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) }

// SetResidualCount sets bytes 44-47.
func (p *PDU) SetResidualCount(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) }
// --- Wire I/O ---

// pad4 rounds n up to the next multiple of 4 (iSCSI segments are
// 4-byte aligned on the wire).
func pad4(n uint32) uint32 {
	return (n + 3) / 4 * 4
}
// ReadPDU reads one complete PDU from r.
//
// A clean io.EOF before the first BHS byte is returned as io.EOF so
// callers can detect an orderly connection close between PDUs. Any end
// of stream after that — including one that falls exactly on the
// BHS/AHS or AHS/data boundary, where io.ReadFull reports a plain
// io.EOF rather than io.ErrUnexpectedEOF — is mid-PDU and is reported
// as ErrPDUTruncated.
func ReadPDU(r io.Reader) (*PDU, error) {
	p := &PDU{}
	// Basic Header Segment: fixed 48 bytes.
	if _, err := io.ReadFull(r, p.BHS[:]); err != nil {
		if err == io.ErrUnexpectedEOF {
			return nil, ErrPDUTruncated
		}
		return nil, err // io.EOF here means a clean close between PDUs
	}
	// Additional Header Segment: length advertised in 4-byte words.
	ahsLen := uint32(p.TotalAHSLength()) * 4
	if ahsLen > 0 {
		p.AHS = make([]byte, ahsLen)
		if _, err := io.ReadFull(r, p.AHS); err != nil {
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				// The header promised more bytes; EOF here is mid-PDU.
				return nil, ErrPDUTruncated
			}
			return nil, err
		}
	}
	// Data segment: validate the 3-byte length field, then read the
	// payload plus its 4-byte-alignment padding.
	dsLen := p.DataSegmentLength()
	if dsLen > MaxDataSegmentLength {
		return nil, fmt.Errorf("%w: %d bytes", ErrPDUTooLarge, dsLen)
	}
	if dsLen > 0 {
		buf := make([]byte, pad4(dsLen))
		if _, err := io.ReadFull(r, buf); err != nil {
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				return nil, ErrPDUTruncated
			}
			return nil, err
		}
		p.DataSegment = buf[:dsLen] // strip padding
	}
	return p, nil
}
// WritePDU serializes p to w: BHS, then AHS, then the data segment
// padded out to a 4-byte boundary. The BHS length fields
// (TotalAHSLength, DataSegmentLength) are refreshed from the actual
// AHS/DataSegment slices first, so p's header is mutated.
func WritePDU(w io.Writer, p *PDU) error {
	ahsLen := len(p.AHS)
	if ahsLen%4 != 0 {
		return ErrInvalidAHSLength
	}
	// Sync header length fields with the slices we are about to write.
	p.BHS[4] = uint8(ahsLen / 4)
	dsLen := uint32(len(p.DataSegment))
	p.SetDataSegmentLength(dsLen)

	if _, err := w.Write(p.BHS[:]); err != nil {
		return err
	}
	if ahsLen > 0 {
		if _, err := w.Write(p.AHS); err != nil {
			return err
		}
	}
	if dsLen == 0 {
		return nil
	}
	if _, err := w.Write(p.DataSegment); err != nil {
		return err
	}
	// Pad the data segment to the next 4-byte boundary with zeros.
	if padding := pad4(dsLen) - dsLen; padding > 0 {
		var zeros [3]byte
		if _, err := w.Write(zeros[:padding]); err != nil {
			return err
		}
	}
	return nil
}
// OpcodeName returns a human-readable name for the given opcode,
// intended for logs and error messages.
func OpcodeName(op uint8) string {
	names := map[uint8]string{
		OpNOPOut:       "NOP-Out",
		OpSCSICmd:      "SCSI-Command",
		OpSCSITaskMgmt: "SCSI-Task-Mgmt",
		OpLoginReq:     "Login-Request",
		OpTextReq:      "Text-Request",
		OpSCSIDataOut:  "SCSI-Data-Out",
		OpLogoutReq:    "Logout-Request",
		OpSNACKReq:     "SNACK-Request",
		OpNOPIn:        "NOP-In",
		OpSCSIResp:     "SCSI-Response",
		OpSCSITaskResp: "SCSI-Task-Mgmt-Response",
		OpLoginResp:    "Login-Response",
		OpTextResp:     "Text-Response",
		OpSCSIDataIn:   "SCSI-Data-In",
		OpLogoutResp:   "Logout-Response",
		OpR2T:          "R2T",
		OpAsyncMsg:     "Async-Message",
		OpReject:       "Reject",
	}
	if name, ok := names[op]; ok {
		return name
	}
	return fmt.Sprintf("Unknown(0x%02x)", op)
}
// Login PDU helpers
//
// Byte 1 of a login PDU packs T (bit 7), C (bit 6), CSG (bits 3-2) and
// NSG (bits 1-0); bytes 36-37 of a login response carry the status
// class and detail.

// LoginCSG returns the Current Stage from byte 1 of a login PDU (bits 3-2).
func (p *PDU) LoginCSG() uint8 { return (p.BHS[1] >> 2) & 0x03 }

// LoginNSG returns the Next Stage from byte 1 of a login PDU (bits 1-0).
func (p *PDU) LoginNSG() uint8 { return p.BHS[1] & 0x03 }

// SetLoginStages sets the CSG and NSG fields in byte 1, preserving T and C flags.
func (p *PDU) SetLoginStages(csg, nsg uint8) {
	p.BHS[1] = (p.BHS[1] & 0xF0) | ((csg & 0x03) << 2) | (nsg & 0x03)
}

// LoginTransit returns true if the Transit bit is set (byte 1, bit 7).
func (p *PDU) LoginTransit() bool { return p.BHS[1]&FlagT != 0 }

// SetLoginTransit sets or clears the Transit bit.
func (p *PDU) SetLoginTransit(v bool) {
	if v {
		p.BHS[1] |= FlagT
	} else {
		p.BHS[1] &^= FlagT
	}
}

// LoginContinue returns true if the Continue bit is set.
func (p *PDU) LoginContinue() bool { return p.BHS[1]&FlagC != 0 }

// LoginStatusClass returns the status class (byte 36) of a login response.
func (p *PDU) LoginStatusClass() uint8 { return p.BHS[36] }

// LoginStatusDetail returns the status detail (byte 37) of a login response.
func (p *PDU) LoginStatusDetail() uint8 { return p.BHS[37] }

// SetLoginStatus sets the status class and detail in a login response.
func (p *PDU) SetLoginStatus(class, detail uint8) {
	p.BHS[36] = class
	p.BHS[37] = detail
}

// SCSIResponse returns the iSCSI response byte (byte 2) of a SCSI Response PDU.
func (p *PDU) SCSIResponse() uint8 { return p.BHS[2] }

// SetSCSIResponse sets byte 2.
func (p *PDU) SetSCSIResponse(v uint8) { p.BHS[2] = v }

// SCSIStatus returns the SCSI status byte (byte 3) of a SCSI Response PDU.
func (p *PDU) SCSIStatus() uint8 { return p.BHS[3] }

// SetSCSIStatus sets byte 3.
func (p *PDU) SetSCSIStatus(v uint8) { p.BHS[3] = v }

559
weed/storage/blockvol/iscsi/pdu_test.go

@ -0,0 +1,559 @@
package iscsi
import (
"bytes"
"encoding/binary"
"io"
"strings"
"testing"
)
func TestPDU(t *testing.T) {
tests := []struct {
name string
run func(t *testing.T)
}{
{"roundtrip_bhs_only", testRoundtripBHSOnly},
{"roundtrip_with_data", testRoundtripWithData},
{"roundtrip_with_ahs_and_data", testRoundtripWithAHSAndData},
{"data_padding", testDataPadding},
{"opcode_accessors", testOpcodeAccessors},
{"immediate_flag", testImmediateFlag},
{"field32_accessors", testField32Accessors},
{"lun_accessor", testLUNAccessor},
{"isid_accessor", testISIDAccessor},
{"cdb_accessor", testCDBAccessor},
{"login_stage_accessors", testLoginStageAccessors},
{"login_transit_continue", testLoginTransitContinue},
{"login_status", testLoginStatus},
{"scsi_response_status", testSCSIResponseStatus},
{"data_segment_length_3byte", testDataSegmentLength3Byte},
{"tsih_accessor", testTSIHAccessor},
{"cmdsn_expstatsn", testCmdSNExpStatSN},
{"statsn_expcmdsn_maxcmdsn", testStatSNExpCmdSNMaxCmdSN},
{"datasn_bufferoffset", testDataSNBufferOffset},
{"r2t_fields", testR2TFields},
{"residual_count", testResidualCount},
{"expected_data_transfer_length", testExpectedDataTransferLength},
{"target_transfer_tag", testTargetTransferTag},
{"opcode_name", testOpcodeName},
{"read_truncated_bhs", testReadTruncatedBHS},
{"read_truncated_data", testReadTruncatedData},
{"read_truncated_ahs", testReadTruncatedAHS},
{"read_oversized_data", testReadOversizedData},
{"read_eof", testReadEOF},
{"write_invalid_ahs_length", testWriteInvalidAHSLength},
{"roundtrip_all_opcodes", testRoundtripAllOpcodes},
{"data_segment_exact_4byte_boundary", testDataSegmentExact4ByteBoundary},
{"zero_length_data_segment", testZeroLengthDataSegment},
{"max_3byte_data_length", testMax3ByteDataLength},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.run(t)
})
}
}
func testRoundtripBHSOnly(t *testing.T) {
p := &PDU{}
p.SetOpcode(OpSCSICmd)
p.SetInitiatorTaskTag(0xDEADBEEF)
p.SetCmdSN(42)
var buf bytes.Buffer
if err := WritePDU(&buf, p); err != nil {
t.Fatal(err)
}
if buf.Len() != BHSLength {
t.Fatalf("expected %d bytes, got %d", BHSLength, buf.Len())
}
p2, err := ReadPDU(&buf)
if err != nil {
t.Fatal(err)
}
if p2.Opcode() != OpSCSICmd {
t.Fatalf("opcode mismatch: got 0x%02x", p2.Opcode())
}
if p2.InitiatorTaskTag() != 0xDEADBEEF {
t.Fatalf("ITT mismatch: got 0x%x", p2.InitiatorTaskTag())
}
if p2.CmdSN() != 42 {
t.Fatalf("CmdSN mismatch: got %d", p2.CmdSN())
}
}
func testRoundtripWithData(t *testing.T) {
p := &PDU{}
p.SetOpcode(OpSCSIDataIn)
p.DataSegment = []byte("hello, iSCSI world!")
var buf bytes.Buffer
if err := WritePDU(&buf, p); err != nil {
t.Fatal(err)
}
// BHS(48) + data(19) + padding(1) = 68
expectedLen := 48 + pad4(uint32(len(p.DataSegment)))
if uint32(buf.Len()) != expectedLen {
t.Fatalf("wire length: expected %d, got %d", expectedLen, buf.Len())
}
p2, err := ReadPDU(&buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(p2.DataSegment, []byte("hello, iSCSI world!")) {
t.Fatalf("data mismatch: %q", p2.DataSegment)
}
}
func testRoundtripWithAHSAndData(t *testing.T) {
p := &PDU{}
p.SetOpcode(OpSCSICmd)
p.AHS = make([]byte, 8) // 2 words
p.AHS[0] = 0xAA
p.AHS[7] = 0xBB
p.DataSegment = []byte{1, 2, 3, 4, 5}
var buf bytes.Buffer
if err := WritePDU(&buf, p); err != nil {
t.Fatal(err)
}
p2, err := ReadPDU(&buf)
if err != nil {
t.Fatal(err)
}
if len(p2.AHS) != 8 {
t.Fatalf("AHS length: expected 8, got %d", len(p2.AHS))
}
if p2.AHS[0] != 0xAA || p2.AHS[7] != 0xBB {
t.Fatal("AHS content mismatch")
}
if !bytes.Equal(p2.DataSegment, []byte{1, 2, 3, 4, 5}) {
t.Fatal("data segment mismatch")
}
}
func testDataPadding(t *testing.T) {
for _, dataLen := range []int{1, 2, 3, 4, 5, 7, 8, 100, 1023} {
p := &PDU{}
p.SetOpcode(OpSCSIDataIn)
p.DataSegment = bytes.Repeat([]byte{0xFF}, dataLen)
var buf bytes.Buffer
if err := WritePDU(&buf, p); err != nil {
t.Fatalf("dataLen=%d: write: %v", dataLen, err)
}
expectedWire := BHSLength + int(pad4(uint32(dataLen)))
if buf.Len() != expectedWire {
t.Fatalf("dataLen=%d: wire=%d, expected=%d", dataLen, buf.Len(), expectedWire)
}
p2, err := ReadPDU(&buf)
if err != nil {
t.Fatalf("dataLen=%d: read: %v", dataLen, err)
}
if len(p2.DataSegment) != dataLen {
t.Fatalf("dataLen=%d: got %d", dataLen, len(p2.DataSegment))
}
}
}
func testOpcodeAccessors(t *testing.T) {
p := &PDU{}
// Set immediate first, then opcode — verify no interference
p.SetImmediate(true)
p.SetOpcode(OpLoginReq)
if p.Opcode() != OpLoginReq {
t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpLoginReq)
}
if !p.Immediate() {
t.Fatal("immediate flag lost")
}
// Change opcode, verify immediate preserved
p.SetOpcode(OpTextReq)
if p.Opcode() != OpTextReq {
t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpTextReq)
}
if !p.Immediate() {
t.Fatal("immediate flag lost after opcode change")
}
}
func testImmediateFlag(t *testing.T) {
p := &PDU{}
if p.Immediate() {
t.Fatal("should start false")
}
p.SetImmediate(true)
if !p.Immediate() {
t.Fatal("should be true")
}
p.SetImmediate(false)
if p.Immediate() {
t.Fatal("should be false again")
}
}
func testField32Accessors(t *testing.T) {
p := &PDU{}
p.SetField32(20, 0x12345678)
if got := p.Field32(20); got != 0x12345678 {
t.Fatalf("got 0x%08x", got)
}
}
func testLUNAccessor(t *testing.T) {
p := &PDU{}
p.SetLUN(0x0001000000000000) // LUN 1 in SAM encoding
if p.LUN() != 0x0001000000000000 {
t.Fatalf("LUN mismatch: 0x%016x", p.LUN())
}
}
func testISIDAccessor(t *testing.T) {
p := &PDU{}
isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}
p.SetISID(isid)
got := p.ISID()
if got != isid {
t.Fatalf("ISID mismatch: %v", got)
}
}
func testCDBAccessor(t *testing.T) {
p := &PDU{}
var cdb [16]byte
cdb[0] = 0x28 // READ_10
cdb[1] = 0x00
binary.BigEndian.PutUint32(cdb[2:6], 100) // LBA
binary.BigEndian.PutUint16(cdb[7:9], 8) // transfer length
p.SetCDB(cdb)
got := p.CDB()
if got != cdb {
t.Fatal("CDB mismatch")
}
}
func testLoginStageAccessors(t *testing.T) {
p := &PDU{}
p.SetLoginStages(StageSecurityNeg, StageLoginOp)
if p.LoginCSG() != StageSecurityNeg {
t.Fatalf("CSG: got %d", p.LoginCSG())
}
if p.LoginNSG() != StageLoginOp {
t.Fatalf("NSG: got %d", p.LoginNSG())
}
// Change to LoginOp -> FullFeature
p.SetLoginStages(StageLoginOp, StageFullFeature)
if p.LoginCSG() != StageLoginOp {
t.Fatalf("CSG: got %d", p.LoginCSG())
}
if p.LoginNSG() != StageFullFeature {
t.Fatalf("NSG: got %d", p.LoginNSG())
}
}
func testLoginTransitContinue(t *testing.T) {
p := &PDU{}
if p.LoginTransit() {
t.Fatal("Transit should start false")
}
if p.LoginContinue() {
t.Fatal("Continue should start false")
}
p.SetLoginTransit(true)
p.SetLoginStages(StageLoginOp, StageFullFeature)
if !p.LoginTransit() {
t.Fatal("Transit should be true")
}
// Verify stages preserved
if p.LoginCSG() != StageLoginOp {
t.Fatal("CSG lost after setting transit")
}
}
func testLoginStatus(t *testing.T) {
p := &PDU{}
p.SetLoginStatus(0x02, 0x01) // Initiator error, authentication failure
if p.LoginStatusClass() != 0x02 {
t.Fatalf("class: got %d", p.LoginStatusClass())
}
if p.LoginStatusDetail() != 0x01 {
t.Fatalf("detail: got %d", p.LoginStatusDetail())
}
}
func testSCSIResponseStatus(t *testing.T) {
p := &PDU{}
p.SetSCSIResponse(ISCSIRespCompleted)
p.SetSCSIStatus(SCSIStatusGood)
if p.SCSIResponse() != ISCSIRespCompleted {
t.Fatal("response mismatch")
}
if p.SCSIStatus() != SCSIStatusGood {
t.Fatal("status mismatch")
}
}
func testDataSegmentLength3Byte(t *testing.T) {
p := &PDU{}
// Test various sizes including >64KB (needs all 3 bytes)
for _, size := range []uint32{0, 1, 255, 256, 65535, 65536, 1<<24 - 1} {
p.SetDataSegmentLength(size)
if got := p.DataSegmentLength(); got != size {
t.Fatalf("size %d: got %d", size, got)
}
}
}
func testTSIHAccessor(t *testing.T) {
p := &PDU{}
p.SetTSIH(0x1234)
if p.TSIH() != 0x1234 {
t.Fatalf("TSIH: got 0x%04x", p.TSIH())
}
}
func testCmdSNExpStatSN(t *testing.T) {
p := &PDU{}
p.SetCmdSN(100)
p.SetExpStatSN(200)
if p.CmdSN() != 100 {
t.Fatal("CmdSN mismatch")
}
if p.ExpStatSN() != 200 {
t.Fatal("ExpStatSN mismatch")
}
}
func testStatSNExpCmdSNMaxCmdSN(t *testing.T) {
	// The three response-side sequence numbers are independent fields.
	pdu := &PDU{}
	pdu.SetStatSN(10)
	pdu.SetExpCmdSN(20)
	pdu.SetMaxCmdSN(30)
	if pdu.StatSN() != 10 {
		t.Fatal("StatSN mismatch")
	}
	if pdu.ExpCmdSN() != 20 {
		t.Fatal("ExpCmdSN mismatch")
	}
	if pdu.MaxCmdSN() != 30 {
		t.Fatal("MaxCmdSN mismatch")
	}
}
func testDataSNBufferOffset(t *testing.T) {
	// DataSN and BufferOffset must round-trip independently.
	pdu := &PDU{}
	pdu.SetDataSN(5)
	pdu.SetBufferOffset(8192)
	if pdu.DataSN() != 5 {
		t.Fatal("DataSN mismatch")
	}
	if pdu.BufferOffset() != 8192 {
		t.Fatal("BufferOffset mismatch")
	}
}
func testR2TFields(t *testing.T) {
	// R2T sequence number and desired transfer length round-trip.
	pdu := &PDU{}
	pdu.SetR2TSN(3)
	pdu.SetDesiredDataLength(65536)
	if pdu.R2TSN() != 3 {
		t.Fatal("R2TSN mismatch")
	}
	if pdu.DesiredDataLength() != 65536 {
		t.Fatal("DesiredDataLength mismatch")
	}
}
func testResidualCount(t *testing.T) {
	// ResidualCount must round-trip a non-zero value.
	pdu := &PDU{}
	pdu.SetResidualCount(512)
	if pdu.ResidualCount() != 512 {
		t.Fatal("ResidualCount mismatch")
	}
}
func testExpectedDataTransferLength(t *testing.T) {
	// A 1 MiB expected transfer length must round-trip.
	pdu := &PDU{}
	pdu.SetExpectedDataTransferLength(1048576)
	if pdu.ExpectedDataTransferLength() != 1048576 {
		t.Fatal("mismatch")
	}
}
func testTargetTransferTag(t *testing.T) {
	// 0xFFFFFFFF is the reserved "no tag" value and must round-trip.
	pdu := &PDU{}
	pdu.SetTargetTransferTag(0xFFFFFFFF)
	if pdu.TargetTransferTag() != 0xFFFFFFFF {
		t.Fatal("mismatch")
	}
}
func testOpcodeName(t *testing.T) {
	// Known opcodes map to their names; unknown ones get an
	// "Unknown" prefix.
	if OpcodeName(OpSCSICmd) != "SCSI-Command" {
		t.Fatal("wrong name for SCSI-Command")
	}
	if OpcodeName(OpLoginReq) != "Login-Request" {
		t.Fatal("wrong name for Login-Request")
	}
	if !strings.HasPrefix(OpcodeName(0xFF), "Unknown") {
		t.Fatal("unknown opcode should have Unknown prefix")
	}
}
func testReadTruncatedBHS(t *testing.T) {
	// A 20-byte stream is shorter than a BHS and must be rejected.
	short := bytes.NewReader(make([]byte, 20))
	if _, err := ReadPDU(short); err == nil {
		t.Fatal("expected error")
	}
}
func testReadTruncatedData(t *testing.T) {
	// Header advertises a 100-byte data segment but only 10 bytes are
	// on the wire; ReadPDU must fail.
	var hdr [BHSLength]byte
	hdr[5] = 0
	hdr[6] = 0
	hdr[7] = 100 // DataSegmentLength = 100
	wire := make([]byte, BHSLength+10)
	copy(wire, hdr[:])
	if _, err := ReadPDU(bytes.NewReader(wire)); err == nil {
		t.Fatal("expected truncation error")
	}
}
func testReadTruncatedAHS(t *testing.T) {
	// Header advertises 2 AHS words (8 bytes) but only 4 follow.
	var hdr [BHSLength]byte
	hdr[4] = 2 // TotalAHSLength = 2 words = 8 bytes
	wire := make([]byte, BHSLength+4)
	copy(wire, hdr[:])
	if _, err := ReadPDU(bytes.NewReader(wire)); err == nil {
		t.Fatal("expected truncation error")
	}
}
func testReadOversizedData(t *testing.T) {
	// A data-segment length one past the maximum must be rejected
	// before any data is read.
	var hdr [BHSLength]byte
	tooBig := uint32(MaxDataSegmentLength + 1)
	hdr[5] = byte(tooBig >> 16)
	hdr[6] = byte(tooBig >> 8)
	hdr[7] = byte(tooBig)
	_, err := ReadPDU(bytes.NewReader(hdr[:]))
	if err == nil {
		t.Fatal("expected oversized error")
	}
	if !strings.Contains(err.Error(), "exceeds maximum") {
		t.Fatalf("unexpected error: %v", err)
	}
}
func testReadEOF(t *testing.T) {
	// An empty stream must surface io.EOF unchanged, so callers can
	// distinguish a clean close from a truncated PDU.
	if _, err := ReadPDU(bytes.NewReader(nil)); err != io.EOF {
		t.Fatalf("expected io.EOF, got %v", err)
	}
}
func testWriteInvalidAHSLength(t *testing.T) {
	// AHS must be a multiple of 4 bytes; 5 is invalid.
	pdu := &PDU{AHS: make([]byte, 5)}
	var sink bytes.Buffer
	if err := WritePDU(&sink, pdu); err != ErrInvalidAHSLength {
		t.Fatalf("expected ErrInvalidAHSLength, got %v", err)
	}
}
func testRoundtripAllOpcodes(t *testing.T) {
	// Every defined opcode must survive a write/read round-trip.
	opcodes := []uint8{
		OpNOPOut, OpSCSICmd, OpSCSITaskMgmt, OpLoginReq, OpTextReq,
		OpSCSIDataOut, OpLogoutReq, OpSNACKReq,
		OpNOPIn, OpSCSIResp, OpSCSITaskResp, OpLoginResp, OpTextResp,
		OpSCSIDataIn, OpLogoutResp, OpR2T, OpAsyncMsg, OpReject,
	}
	for _, op := range opcodes {
		out := &PDU{}
		out.SetOpcode(op)
		var wire bytes.Buffer
		if err := WritePDU(&wire, out); err != nil {
			t.Fatalf("op=0x%02x: write: %v", op, err)
		}
		in, err := ReadPDU(&wire)
		if err != nil {
			t.Fatalf("op=0x%02x: read: %v", op, err)
		}
		if in.Opcode() != op {
			t.Fatalf("op=0x%02x: got 0x%02x", op, in.Opcode())
		}
	}
}
func testDataSegmentExact4ByteBoundary(t *testing.T) {
	// Sizes already on a 4-byte boundary must not gain padding on the
	// wire and must round-trip at the same length.
	for _, size := range []int{4, 8, 12, 16, 256, 4096} {
		out := &PDU{DataSegment: bytes.Repeat([]byte{0xAB}, size)}
		var wire bytes.Buffer
		if err := WritePDU(&wire, out); err != nil {
			t.Fatalf("size=%d: %v", size, err)
		}
		// No padding needed
		if wire.Len() != BHSLength+size {
			t.Fatalf("size=%d: wire=%d, expected=%d", size, wire.Len(), BHSLength+size)
		}
		in, err := ReadPDU(&wire)
		if err != nil {
			t.Fatalf("size=%d: read: %v", size, err)
		}
		if len(in.DataSegment) != size {
			t.Fatalf("size=%d: got %d", size, len(in.DataSegment))
		}
	}
}
func testZeroLengthDataSegment(t *testing.T) {
	// A PDU with a nil DataSegment writes exactly one BHS and reads
	// back with an empty segment.
	out := &PDU{}
	out.SetOpcode(OpNOPOut)
	var wire bytes.Buffer
	if err := WritePDU(&wire, out); err != nil {
		t.Fatal(err)
	}
	if wire.Len() != BHSLength {
		t.Fatalf("expected %d, got %d", BHSLength, wire.Len())
	}
	in, err := ReadPDU(&wire)
	if err != nil {
		t.Fatal(err)
	}
	if len(in.DataSegment) != 0 {
		t.Fatalf("expected empty data segment, got %d bytes", len(in.DataSegment))
	}
}
func testMax3ByteDataLength(t *testing.T) {
	// The largest value the 24-bit field can carry must round-trip.
	const maxLen = uint32(1<<24 - 1) // 16MB - 1
	pdu := &PDU{}
	pdu.SetDataSegmentLength(maxLen)
	if got := pdu.DataSegmentLength(); got != maxLen {
		t.Fatalf("expected %d, got %d", maxLen, got)
	}
}

463
weed/storage/blockvol/iscsi/scsi.go

@ -0,0 +1,463 @@
package iscsi
import (
"encoding/binary"
)
// SCSI opcode constants (SPC-5 / SBC-4)
// These are the CDB byte-0 values recognized by SCSIHandler.HandleCommand.
const (
	ScsiTestUnitReady uint8 = 0x00
	ScsiInquiry uint8 = 0x12
	ScsiModeSense6 uint8 = 0x1a
	ScsiReadCapacity10 uint8 = 0x25
	ScsiRead10 uint8 = 0x28
	ScsiWrite10 uint8 = 0x2a
	ScsiSyncCache10 uint8 = 0x35
	ScsiUnmap uint8 = 0x42
	ScsiReportLuns uint8 = 0xa0
	ScsiRead16 uint8 = 0x88
	ScsiWrite16 uint8 = 0x8a
	ScsiReadCapacity16 uint8 = 0x9e // SERVICE ACTION IN (16), SA=0x10
	ScsiSyncCache16 uint8 = 0x91
)
// Service action for READ CAPACITY (16)
// (low 5 bits of CDB byte 1 when the opcode is SERVICE ACTION IN (16)).
const ScsiSAReadCapacity16 uint8 = 0x10
// SCSI sense keys
// Reported in SCSIResult.SenseKey when Status is CHECK CONDITION, and
// packed into fixed-format sense data by BuildSenseData.
const (
	SenseNoSense uint8 = 0x00
	SenseNotReady uint8 = 0x02
	SenseMediumError uint8 = 0x03
	SenseHardwareError uint8 = 0x04
	SenseIllegalRequest uint8 = 0x05
	SenseAbortedCommand uint8 = 0x0b
)
// ASC/ASCQ pairs
// Additional sense code / qualifier values used by this handler.
const (
	ASCInvalidOpcode uint8 = 0x20 // INVALID COMMAND OPERATION CODE
	ASCQLuk uint8 = 0x00 // default zero qualifier used with all ASCs here
	ASCInvalidFieldInCDB uint8 = 0x24
	ASCLBAOutOfRange uint8 = 0x21 // LOGICAL BLOCK ADDRESS OUT OF RANGE
	ASCNotReady uint8 = 0x04
	ASCQNotReady uint8 = 0x03 // manual intervention required
)
// BlockDevice is the interface that the SCSI command handler uses to
// interact with the underlying storage. This maps onto BlockVol.
//
// Addresses are in logical blocks (LBA); ReadAt/Trim lengths are in bytes
// and are expected to be multiples of BlockSize by the callers in this file.
type BlockDevice interface {
	ReadAt(lba uint64, length uint32) ([]byte, error)
	WriteAt(lba uint64, data []byte) error
	Trim(lba uint64, length uint32) error
	SyncCache() error
	BlockSize() uint32
	VolumeSize() uint64 // total size in bytes
	IsHealthy() bool // false makes TEST UNIT READY report NOT READY
}
// SCSIHandler processes SCSI commands from iSCSI PDUs.
// It is a stateless dispatcher over a single BlockDevice (one LUN).
type SCSIHandler struct {
	dev BlockDevice
	vendorID string // 8 bytes for INQUIRY
	prodID string // 16 bytes for INQUIRY
	serial string // for VPD page 0x80
}
// NewSCSIHandler creates a SCSI command handler for the given block device.
// The vendor/product/serial identity strings are fixed defaults surfaced
// by INQUIRY and VPD page 0x80; padding is applied later by padRight.
func NewSCSIHandler(dev BlockDevice) *SCSIHandler {
	return &SCSIHandler{
		dev: dev,
		vendorID: "SeaweedF",
		prodID: "BlockVol ",
		serial: "SWF00001",
	}
}
// SCSIResult holds the result of a SCSI command execution.
// Sense fields are only meaningful when Status is CHECK CONDITION.
type SCSIResult struct {
	Status uint8 // SCSI status
	Data []byte // Response data (for Data-In)
	SenseKey uint8 // Sense key (if CHECK_CONDITION)
	SenseASC uint8 // Additional sense code
	SenseASCQ uint8 // Additional sense code qualifier
}
// HandleCommand dispatches a SCSI CDB to the appropriate handler.
// dataOut contains any data sent by the initiator (for WRITE and UNMAP
// commands); it may be nil for commands with no Data-Out phase.
// Unsupported opcodes yield CHECK CONDITION / ILLEGAL REQUEST.
func (h *SCSIHandler) HandleCommand(cdb [16]byte, dataOut []byte) SCSIResult {
	switch cdb[0] {
	case ScsiTestUnitReady:
		return h.testUnitReady()
	case ScsiInquiry:
		return h.inquiry(cdb)
	case ScsiModeSense6:
		return h.modeSense6(cdb)
	case ScsiReadCapacity10:
		return h.readCapacity10()
	case ScsiReadCapacity16:
		// SERVICE ACTION IN (16): only READ CAPACITY (16) (SA=0x10) is
		// implemented. Per SPC-5 an unsupported service action is an
		// illegal field in the CDB (ASC 0x24), not an invalid opcode.
		if cdb[1]&0x1f == ScsiSAReadCapacity16 {
			return h.readCapacity16(cdb)
		}
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	case ScsiReportLuns:
		return h.reportLuns(cdb)
	case ScsiRead10:
		return h.read10(cdb)
	case ScsiRead16:
		return h.read16(cdb)
	case ScsiWrite10:
		return h.write10(cdb, dataOut)
	case ScsiWrite16:
		return h.write16(cdb, dataOut)
	case ScsiSyncCache10, ScsiSyncCache16:
		// Both SYNCHRONIZE CACHE variants flush the whole device.
		return h.syncCache()
	case ScsiUnmap:
		return h.unmap(cdb, dataOut)
	default:
		return illegalRequest(ASCInvalidOpcode, ASCQLuk)
	}
}
// --- Metadata commands ---
// testUnitReady implements TEST UNIT READY: GOOD when the device reports
// healthy, otherwise CHECK CONDITION with NOT READY sense.
func (h *SCSIHandler) testUnitReady() SCSIResult {
	if h.dev.IsHealthy() {
		return SCSIResult{Status: SCSIStatusGood}
	}
	return SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  SenseNotReady,
		SenseASC:  ASCNotReady,
		SenseASCQ: ASCQNotReady,
	}
}
// inquiry implements INQUIRY (SPC-5, 6.6): standard inquiry data, or a VPD
// page when the EVPD bit is set. The response is truncated to the CDB's
// allocation length; an allocation length of 0 is treated as 36.
func (h *SCSIHandler) inquiry(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint16(cdb[3:5])
	if allocLen == 0 {
		allocLen = 36
	}
	if cdb[1]&0x01 != 0 { // EVPD bit: serve the requested VPD page instead
		return h.inquiryVPD(cdb[2], allocLen)
	}
	// Standard INQUIRY response (SPC-5, Section 6.6.1)
	resp := make([]byte, 96)
	resp[0] = 0x00 // peripheral device type: SBC (direct access block device)
	resp[1] = 0x00 // RMB=0 (not removable)
	resp[2] = 0x06 // SPC-4 version
	resp[3] = 0x02 // response data format = 2 (SPC-2+)
	resp[4] = 91   // additional length (96-5)
	resp[5] = 0x00 // SCCS, ACC, TPGS, 3PC
	resp[6] = 0x00 // obsolete, EncServ, VS, MultiP
	resp[7] = 0x02 // CmdQue=1 (supports command queuing)
	copy(resp[8:16], padRight(h.vendorID, 8)) // vendor ID, 8 chars space padded
	copy(resp[16:32], padRight(h.prodID, 16)) // product ID, 16 chars space padded
	copy(resp[32:36], "0001")                 // product revision
	if int(allocLen) < len(resp) {
		resp = resp[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: resp}
}
// inquiryVPD serves the three supported VPD pages: 0x00 (supported pages),
// 0x80 (unit serial number), and 0x83 (device identification). Any other
// page code is an ILLEGAL REQUEST / INVALID FIELD IN CDB.
func (h *SCSIHandler) inquiryVPD(pageCode uint8, allocLen uint16) SCSIResult {
	var page []byte
	switch pageCode {
	case 0x00: // Supported VPD pages
		page = []byte{
			0x00,       // device type
			0x00,       // page code
			0x00, 0x03, // page length
			0x00, // supported pages: 0x00
			0x80, // 0x80 (serial)
			0x83, // 0x83 (device identification)
		}
	case 0x80: // Unit serial number
		serial := padRight(h.serial, 8)
		page = make([]byte, 4+len(serial))
		page[0] = 0x00 // device type
		page[1] = 0x80 // page code
		binary.BigEndian.PutUint16(page[2:4], uint16(len(serial)))
		copy(page[4:], serial)
	case 0x83: // Device identification
		// Single designator: an 8-byte NAA identifier.
		naaID := []byte{
			0x01, // code set: binary
			0x03, // identifier type: NAA
			0x00, // reserved
			0x08, // identifier length
			0x60, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // NAA-6 fake
		}
		page = make([]byte, 4+len(naaID))
		page[0] = 0x00 // device type
		page[1] = 0x83 // page code
		binary.BigEndian.PutUint16(page[2:4], uint16(len(naaID)))
		copy(page[4:], naaID)
	default:
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
	if int(allocLen) < len(page) {
		page = page[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: page}
}
// readCapacity10 implements READ CAPACITY (10) (SBC-4): the LBA of the
// last logical block (4 bytes) followed by the block size in bytes (4 bytes).
func (h *SCSIHandler) readCapacity10() SCSIResult {
	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)
	data := make([]byte, 8)
	switch {
	case totalBlocks == 0:
		// Degenerate volume smaller than one block: guard the uint64
		// underflow of totalBlocks-1, which would otherwise read back
		// as 0xFFFFFFFF and falsely signal a >2TB device.
		binary.BigEndian.PutUint32(data[0:4], 0)
	case totalBlocks > 0xFFFFFFFF:
		// >2TB (blocks > 0xFFFFFFFF): return 0xFFFFFFFF to signal the
		// initiator to use READ CAPACITY (16).
		binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF)
	default:
		binary.BigEndian.PutUint32(data[0:4], uint32(totalBlocks-1)) // last LBA
	}
	binary.BigEndian.PutUint32(data[4:8], blockSize)
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
// readCapacity16 implements READ CAPACITY (16) (SBC-4): last LBA (8 bytes),
// block length (4 bytes), and the LBPME bit advertising UNMAP support.
func (h *SCSIHandler) readCapacity16(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint32(cdb[10:14])
	if allocLen < 32 {
		allocLen = 32
	}
	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)
	var lastLBA uint64
	if totalBlocks > 0 {
		// Guard the totalBlocks-1 underflow for a volume smaller than
		// one block.
		lastLBA = totalBlocks - 1
	}
	data := make([]byte, 32)
	binary.BigEndian.PutUint64(data[0:8], lastLBA)    // last LBA
	binary.BigEndian.PutUint32(data[8:12], blockSize) // block length
	// Byte 14 bit 7 is LBPME (logical block provisioning management
	// enabled); set it to advertise UNMAP support.
	data[14] = 0x80
	if allocLen < uint32(len(data)) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
// modeSense6 implements a minimal MODE SENSE (6): a 4-byte header with no
// mode pages and no block descriptors, truncated to the allocation length.
func (h *SCSIHandler) modeSense6(cdb [16]byte) SCSIResult {
	allocLen := int(cdb[4])
	if allocLen == 0 {
		allocLen = 4
	}
	hdr := []byte{
		3,    // mode data length (3 bytes follow)
		0x00, // medium type: default
		0x00, // device-specific parameter (no write protect)
		0x00, // block descriptor length = 0
	}
	if allocLen < len(hdr) {
		hdr = hdr[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: hdr}
}
// reportLuns implements REPORT LUNS for a single-LUN target: a 4-byte list
// length, 4 reserved bytes, and one all-zero 8-byte entry for LUN 0.
func (h *SCSIHandler) reportLuns(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint32(cdb[6:10])
	if allocLen < 16 {
		allocLen = 16
	}
	list := make([]byte, 16)
	binary.BigEndian.PutUint32(list[0:4], 8) // LUN list length: 8 bytes = 1 LUN
	// list[4:8] reserved; list[8:16] is LUN 0 (all zeros).
	if uint32(len(list)) > allocLen {
		list = list[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: list}
}
// --- Data commands (Task 2.6) ---
// read10 decodes a READ (10) CDB (32-bit LBA at bytes 2-5, 16-bit block
// count at bytes 7-8) and delegates to doRead.
func (h *SCSIHandler) read10(cdb [16]byte) SCSIResult {
	startLBA := uint64(binary.BigEndian.Uint32(cdb[2:6]))
	blocks := uint32(binary.BigEndian.Uint16(cdb[7:9]))
	return h.doRead(startLBA, blocks)
}
// read16 decodes a READ (16) CDB (64-bit LBA at bytes 2-9, 32-bit block
// count at bytes 10-13) and delegates to doRead.
func (h *SCSIHandler) read16(cdb [16]byte) SCSIResult {
	startLBA := binary.BigEndian.Uint64(cdb[2:10])
	blocks := binary.BigEndian.Uint32(cdb[10:14])
	return h.doRead(startLBA, blocks)
}
// write10 decodes a WRITE (10) CDB (32-bit LBA, 16-bit block count) and
// delegates to doWrite with the initiator's Data-Out payload.
func (h *SCSIHandler) write10(cdb [16]byte, dataOut []byte) SCSIResult {
	startLBA := uint64(binary.BigEndian.Uint32(cdb[2:6]))
	blocks := uint32(binary.BigEndian.Uint16(cdb[7:9]))
	return h.doWrite(startLBA, blocks, dataOut)
}
// write16 decodes a WRITE (16) CDB (64-bit LBA, 32-bit block count) and
// delegates to doWrite with the initiator's Data-Out payload.
func (h *SCSIHandler) write16(cdb [16]byte, dataOut []byte) SCSIResult {
	startLBA := binary.BigEndian.Uint64(cdb[2:10])
	blocks := binary.BigEndian.Uint32(cdb[10:14])
	return h.doWrite(startLBA, blocks, dataOut)
}
// doRead performs the shared READ (10/16) path: range-check the request,
// read transferLen blocks starting at lba, and map device errors to
// MEDIUM ERROR sense data.
func (h *SCSIHandler) doRead(lba uint64, transferLen uint32) SCSIResult {
	if transferLen == 0 {
		// A zero-length transfer is not an error.
		return SCSIResult{Status: SCSIStatusGood}
	}
	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)
	// Overflow-safe range check: lba+transferLen can wrap uint64 for a
	// hostile LBA near 2^64, so compare against the remaining blocks.
	if lba >= totalBlocks || uint64(transferLen) > totalBlocks-lba {
		return illegalRequest(ASCLBAOutOfRange, ASCQLuk)
	}
	// transferLen*blockSize can exceed 32 bits (READ (16) allows up to
	// 2^32-1 blocks); ReadAt takes a uint32 byte count, so reject such a
	// request instead of silently truncating the length.
	byteLen := uint64(transferLen) * uint64(blockSize)
	if byteLen > 0xFFFFFFFF {
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
	data, err := h.dev.ReadAt(lba, uint32(byteLen))
	if err != nil {
		return SCSIResult{
			Status:    SCSIStatusCheckCond,
			SenseKey:  SenseMediumError,
			SenseASC:  0x11, // Unrecovered read error
			SenseASCQ: 0x00,
		}
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
// doWrite performs the shared WRITE (10/16) path: range-check the request,
// verify the initiator supplied enough Data-Out bytes, and write exactly
// transferLen blocks starting at lba.
func (h *SCSIHandler) doWrite(lba uint64, transferLen uint32, dataOut []byte) SCSIResult {
	if transferLen == 0 {
		// A zero-length transfer is not an error.
		return SCSIResult{Status: SCSIStatusGood}
	}
	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)
	// Overflow-safe range check: lba+transferLen can wrap uint64, so
	// compare against the remaining blocks instead.
	if lba >= totalBlocks || uint64(transferLen) > totalBlocks-lba {
		return illegalRequest(ASCLBAOutOfRange, ASCQLuk)
	}
	// Compute the byte count in uint64: transferLen*blockSize can wrap
	// uint32, which would make the short-payload check pass spuriously.
	expectedBytes := uint64(transferLen) * uint64(blockSize)
	if uint64(len(dataOut)) < expectedBytes {
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
	if err := h.dev.WriteAt(lba, dataOut[:expectedBytes]); err != nil {
		return SCSIResult{
			Status:    SCSIStatusCheckCond,
			SenseKey:  SenseMediumError,
			SenseASC:  0x0C, // Write error
			SenseASCQ: 0x00,
		}
	}
	return SCSIResult{Status: SCSIStatusGood}
}
// syncCache implements SYNCHRONIZE CACHE (10/16) by flushing the device;
// a flush failure is surfaced as a HARDWARE ERROR.
func (h *SCSIHandler) syncCache() SCSIResult {
	err := h.dev.SyncCache()
	if err == nil {
		return SCSIResult{Status: SCSIStatusGood}
	}
	return SCSIResult{
		Status:   SCSIStatusCheckCond,
		SenseKey: SenseHardwareError,
		// ASC/ASCQ left at 0x00/0x00: no additional sense information.
	}
}
// unmap implements UNMAP (SBC-4): parse the parameter list header and each
// 16-byte block descriptor, and Trim the described LBA ranges.
func (h *SCSIHandler) unmap(cdb [16]byte, dataOut []byte) SCSIResult {
	// UNMAP parameter list header is 8 bytes.
	if len(dataOut) < 8 {
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
	// dataOut[0:2] is the overall data length (unused here).
	blockDescLen := binary.BigEndian.Uint16(dataOut[2:4])
	if int(blockDescLen)+8 > len(dataOut) {
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
	blockSize := h.dev.BlockSize()
	// Largest whole number of blocks whose byte length still fits the
	// uint32 length parameter of Trim; larger descriptors are split so
	// numBlocks*blockSize cannot wrap uint32 (the original overflow bug).
	maxBlocksPerTrim := uint64(0xFFFFFFFF) / uint64(blockSize)
	// Each UNMAP block descriptor is 16 bytes.
	descData := dataOut[8 : 8+blockDescLen]
	for len(descData) >= 16 {
		lba := binary.BigEndian.Uint64(descData[0:8])
		remaining := uint64(binary.BigEndian.Uint32(descData[8:12]))
		// descData[12:16] reserved
		for remaining > 0 {
			chunk := remaining
			if chunk > maxBlocksPerTrim {
				chunk = maxBlocksPerTrim
			}
			if err := h.dev.Trim(lba, uint32(chunk*uint64(blockSize))); err != nil {
				return SCSIResult{
					Status:    SCSIStatusCheckCond,
					SenseKey:  SenseMediumError,
					SenseASC:  0x0C,
					SenseASCQ: 0x00,
				}
			}
			lba += chunk
			remaining -= chunk
		}
		descData = descData[16:]
	}
	return SCSIResult{Status: SCSIStatusGood}
}
// BuildSenseData constructs a fixed-format sense data buffer (18 bytes)
// carrying the given sense key, ASC, and ASCQ (SPC-5 fixed format).
func BuildSenseData(key, asc, ascq uint8) []byte {
	sense := make([]byte, 18)
	sense[0] = 0x70       // response code: current errors, fixed format
	sense[2] = key & 0x0f // sense key lives in the low nibble of byte 2
	sense[7] = 10         // additional sense length (bytes 8..17)
	sense[12] = asc       // additional sense code
	sense[13] = ascq      // additional sense code qualifier
	return sense
}
// illegalRequest builds a CHECK CONDITION result carrying ILLEGAL REQUEST
// sense with the supplied ASC/ASCQ pair.
func illegalRequest(asc, ascq uint8) SCSIResult {
	res := SCSIResult{
		Status:   SCSIStatusCheckCond,
		SenseKey: SenseIllegalRequest,
	}
	res.SenseASC = asc
	res.SenseASCQ = ascq
	return res
}
// padRight returns s space-padded on the right to exactly n bytes,
// truncating when s is longer than n.
func padRight(s string, n int) string {
	if len(s) >= n {
		return s[:n]
	}
	buf := make([]byte, n)
	for pos := copy(buf, s); pos < n; pos++ {
		buf[pos] = ' '
	}
	return string(buf)
}

692
weed/storage/blockvol/iscsi/scsi_test.go

@ -0,0 +1,692 @@
package iscsi
import (
"encoding/binary"
"errors"
"testing"
)
// mockBlockDevice implements BlockDevice for testing.
// Storage is a sparse map keyed by LBA; the error fields, when set, are
// returned verbatim by the corresponding method to simulate failures.
type mockBlockDevice struct {
	blockSize uint32
	volumeSize uint64
	healthy bool
	blocks map[uint64][]byte // LBA -> data
	syncErr error
	readErr error
	writeErr error
	trimErr error
}
// newMockDevice returns a healthy mock with 4 KiB blocks and the given
// total size in bytes.
func newMockDevice(volumeSize uint64) *mockBlockDevice {
	m := &mockBlockDevice{
		blockSize:  4096,
		healthy:    true,
		volumeSize: volumeSize,
	}
	m.blocks = map[uint64][]byte{}
	return m
}
// ReadAt returns length bytes starting at lba; blocks never written read
// back as zeros.
func (m *mockBlockDevice) ReadAt(lba uint64, length uint32) ([]byte, error) {
	if m.readErr != nil {
		return nil, m.readErr
	}
	out := make([]byte, length)
	for blk := uint32(0); blk < length/m.blockSize; blk++ {
		stored, ok := m.blocks[lba+uint64(blk)]
		if ok {
			copy(out[blk*m.blockSize:], stored)
		}
		// Absent entries stay zero (out is already zeroed).
	}
	return out, nil
}
// WriteAt stores data block-by-block starting at lba, copying each block
// so callers can reuse their buffer.
func (m *mockBlockDevice) WriteAt(lba uint64, data []byte) error {
	if m.writeErr != nil {
		return m.writeErr
	}
	for blk := uint32(0); blk < uint32(len(data))/m.blockSize; blk++ {
		stored := make([]byte, m.blockSize)
		copy(stored, data[blk*m.blockSize:])
		m.blocks[lba+uint64(blk)] = stored
	}
	return nil
}
// Trim drops length bytes worth of blocks starting at lba from the map,
// so subsequent reads of that range return zeros.
func (m *mockBlockDevice) Trim(lba uint64, length uint32) error {
	if m.trimErr != nil {
		return m.trimErr
	}
	for blk := uint32(0); blk < length/m.blockSize; blk++ {
		delete(m.blocks, lba+uint64(blk))
	}
	return nil
}
// Remaining BlockDevice methods: trivial field pass-throughs.
func (m *mockBlockDevice) SyncCache() error { return m.syncErr }
func (m *mockBlockDevice) BlockSize() uint32 { return m.blockSize }
func (m *mockBlockDevice) VolumeSize() uint64 { return m.volumeSize }
func (m *mockBlockDevice) IsHealthy() bool { return m.healthy }
// TestSCSI is the single entry point for the SCSI handler suite: a
// name -> function registry run as subtests so `go test -run TestSCSI/name`
// can target one case.
func TestSCSI(t *testing.T) {
	tests := []struct {
		name string
		run func(t *testing.T)
	}{
		{"test_unit_ready_good", testTestUnitReadyGood},
		{"test_unit_ready_not_ready", testTestUnitReadyNotReady},
		{"inquiry_standard", testInquiryStandard},
		{"inquiry_vpd_supported_pages", testInquiryVPDSupportedPages},
		{"inquiry_vpd_serial", testInquiryVPDSerial},
		{"inquiry_vpd_device_id", testInquiryVPDDeviceID},
		{"inquiry_vpd_unknown_page", testInquiryVPDUnknownPage},
		{"inquiry_alloc_length", testInquiryAllocLength},
		{"read_capacity_10", testReadCapacity10},
		{"read_capacity_10_large", testReadCapacity10Large},
		{"read_capacity_16", testReadCapacity16},
		{"read_capacity_16_lbpme", testReadCapacity16LBPME},
		{"mode_sense_6", testModeSense6},
		{"report_luns", testReportLuns},
		{"unknown_opcode", testUnknownOpcode},
		{"read_10", testRead10},
		{"read_16", testRead16},
		{"write_10", testWrite10},
		{"write_16", testWrite16},
		{"read_write_roundtrip", testReadWriteRoundtrip},
		{"write_oob", testWriteOOB},
		{"read_oob", testReadOOB},
		{"zero_length_transfer", testZeroLengthTransfer},
		{"sync_cache", testSyncCache},
		{"sync_cache_error", testSyncCacheError},
		{"unmap_single", testUnmapSingle},
		{"unmap_multiple_descriptors", testUnmapMultipleDescriptors},
		{"unmap_short_param", testUnmapShortParam},
		{"build_sense_data", testBuildSenseData},
		{"read_error", testReadError},
		{"write_error", testWriteError},
		{"read_capacity_16_invalid_sa", testReadCapacity16InvalidSA},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
func testTestUnitReadyGood(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiTestUnitReady
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatalf("status: %d", r.Status)
}
}
func testTestUnitReadyNotReady(t *testing.T) {
	// An unhealthy device must answer CHECK CONDITION / NOT READY.
	mock := newMockDevice(1024 * 1024)
	mock.healthy = false
	handler := NewSCSIHandler(mock)
	var cdb [16]byte
	cdb[0] = ScsiTestUnitReady
	res := handler.HandleCommand(cdb, nil)
	if res.Status != SCSIStatusCheckCond {
		t.Fatalf("status: %d", res.Status)
	}
	if res.SenseKey != SenseNotReady {
		t.Fatalf("sense key: %d", res.SenseKey)
	}
}
func testInquiryStandard(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
binary.BigEndian.PutUint16(cdb[3:5], 96)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatalf("status: %d", r.Status)
}
if len(r.Data) != 96 {
t.Fatalf("data length: %d", len(r.Data))
}
// Check peripheral device type
if r.Data[0] != 0x00 {
t.Fatal("not SBC device type")
}
// Check vendor
vendor := string(r.Data[8:16])
if vendor != "SeaweedF" {
t.Fatalf("vendor: %q", vendor)
}
// Check CmdQue
if r.Data[7]&0x02 == 0 {
t.Fatal("CmdQue not set")
}
}
func testInquiryVPDSupportedPages(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
cdb[1] = 0x01 // EVPD
cdb[2] = 0x00 // Supported pages
binary.BigEndian.PutUint16(cdb[3:5], 255)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatalf("status: %d", r.Status)
}
if r.Data[1] != 0x00 {
t.Fatal("wrong page code")
}
// Should list pages 0x00, 0x80, 0x83
if len(r.Data) < 7 || r.Data[4] != 0x00 || r.Data[5] != 0x80 || r.Data[6] != 0x83 {
t.Fatalf("supported pages: %v", r.Data)
}
}
func testInquiryVPDSerial(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
cdb[1] = 0x01
cdb[2] = 0x80
binary.BigEndian.PutUint16(cdb[3:5], 255)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatalf("status: %d", r.Status)
}
if r.Data[1] != 0x80 {
t.Fatal("wrong page code")
}
}
func testInquiryVPDDeviceID(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
cdb[1] = 0x01
cdb[2] = 0x83
binary.BigEndian.PutUint16(cdb[3:5], 255)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatalf("status: %d", r.Status)
}
if r.Data[1] != 0x83 {
t.Fatal("wrong page code")
}
}
func testInquiryVPDUnknownPage(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
cdb[1] = 0x01
cdb[2] = 0xFF // unknown page
binary.BigEndian.PutUint16(cdb[3:5], 255)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusCheckCond {
t.Fatal("expected CHECK_CONDITION")
}
if r.SenseKey != SenseIllegalRequest {
t.Fatal("expected ILLEGAL_REQUEST")
}
}
func testInquiryAllocLength(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiInquiry
binary.BigEndian.PutUint16(cdb[3:5], 10) // small alloc
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("should succeed")
}
if len(r.Data) != 10 {
t.Fatalf("truncation: expected 10, got %d", len(r.Data))
}
}
func testReadCapacity10(t *testing.T) {
dev := newMockDevice(100 * 4096) // 100 blocks
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiReadCapacity10
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("status not good")
}
if len(r.Data) != 8 {
t.Fatalf("data length: %d", len(r.Data))
}
lastLBA := binary.BigEndian.Uint32(r.Data[0:4])
blockSize := binary.BigEndian.Uint32(r.Data[4:8])
if lastLBA != 99 {
t.Fatalf("last LBA: %d, expected 99", lastLBA)
}
if blockSize != 4096 {
t.Fatalf("block size: %d", blockSize)
}
}
func testReadCapacity10Large(t *testing.T) {
// Volume with >2^32 blocks should return 0xFFFFFFFF
dev := newMockDevice(uint64(0x100000001) * 4096) // 2^32+1 blocks
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiReadCapacity10
r := h.HandleCommand(cdb, nil)
lastLBA := binary.BigEndian.Uint32(r.Data[0:4])
if lastLBA != 0xFFFFFFFF {
t.Fatalf("should return 0xFFFFFFFF for >2TB, got %d", lastLBA)
}
}
func testReadCapacity16(t *testing.T) {
dev := newMockDevice(3 * 1024 * 1024 * 1024 * 1024) // 3 TB
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiReadCapacity16
cdb[1] = ScsiSAReadCapacity16
binary.BigEndian.PutUint32(cdb[10:14], 32)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("status not good")
}
lastLBA := binary.BigEndian.Uint64(r.Data[0:8])
expectedBlocks := uint64(3*1024*1024*1024*1024) / 4096
if lastLBA != expectedBlocks-1 {
t.Fatalf("last LBA: %d, expected %d", lastLBA, expectedBlocks-1)
}
}
func testReadCapacity16LBPME(t *testing.T) {
dev := newMockDevice(100 * 4096)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiReadCapacity16
cdb[1] = ScsiSAReadCapacity16
binary.BigEndian.PutUint32(cdb[10:14], 32)
r := h.HandleCommand(cdb, nil)
// LBPME bit should be set (byte 14, bit 7)
if r.Data[14]&0x80 == 0 {
t.Fatal("LBPME bit not set")
}
}
func testModeSense6(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiModeSense6
cdb[4] = 255
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("status not good")
}
if len(r.Data) != 4 {
t.Fatalf("mode sense data: %d bytes", len(r.Data))
}
// No write protect
if r.Data[2]&0x80 != 0 {
t.Fatal("write protect set")
}
}
func testReportLuns(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiReportLuns
binary.BigEndian.PutUint32(cdb[6:10], 256)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("status not good")
}
lunListLen := binary.BigEndian.Uint32(r.Data[0:4])
if lunListLen != 8 {
t.Fatalf("LUN list length: %d (expected 8 for 1 LUN)", lunListLen)
}
}
func testUnknownOpcode(t *testing.T) {
dev := newMockDevice(1024 * 1024)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = 0xFF
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusCheckCond {
t.Fatal("expected CHECK_CONDITION")
}
if r.SenseKey != SenseIllegalRequest {
t.Fatal("expected ILLEGAL_REQUEST")
}
}
func testRead10(t *testing.T) {
	// Seed one block directly, then read it back via READ (10).
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	seed := make([]byte, 4096)
	for i := range seed {
		seed[i] = 0xAB
	}
	mock.blocks[5] = seed
	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 5) // LBA=5
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	res := handler.HandleCommand(cdb, nil)
	if res.Status != SCSIStatusGood {
		t.Fatal("read failed")
	}
	if len(res.Data) != 4096 {
		t.Fatalf("data length: %d", len(res.Data))
	}
	if res.Data[0] != 0xAB {
		t.Fatal("data mismatch")
	}
}
func testRead16(t *testing.T) {
	// READ (16) with a 64-bit LBA returns the seeded block.
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	seed := make([]byte, 4096)
	seed[0] = 0xCD
	mock.blocks[10] = seed
	var cdb [16]byte
	cdb[0] = ScsiRead16
	binary.BigEndian.PutUint64(cdb[2:10], 10)
	binary.BigEndian.PutUint32(cdb[10:14], 1)
	res := handler.HandleCommand(cdb, nil)
	if res.Status != SCSIStatusGood {
		t.Fatal("read16 failed")
	}
	if res.Data[0] != 0xCD {
		t.Fatal("data mismatch")
	}
}
func testWrite10(t *testing.T) {
	// WRITE (10) stores the Data-Out payload at the requested LBA.
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	payload := make([]byte, 4096)
	payload[0] = 0xEF
	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 7)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	res := handler.HandleCommand(cdb, payload)
	if res.Status != SCSIStatusGood {
		t.Fatal("write failed")
	}
	if mock.blocks[7][0] != 0xEF {
		t.Fatal("data not written")
	}
}
func testWrite16(t *testing.T) {
	// A 2-block WRITE (16) lands each block at the right LBA.
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	payload := make([]byte, 8192)
	payload[0] = 0x11
	payload[4096] = 0x22
	var cdb [16]byte
	cdb[0] = ScsiWrite16
	binary.BigEndian.PutUint64(cdb[2:10], 50)
	binary.BigEndian.PutUint32(cdb[10:14], 2)
	res := handler.HandleCommand(cdb, payload)
	if res.Status != SCSIStatusGood {
		t.Fatal("write16 failed")
	}
	if mock.blocks[50][0] != 0x11 {
		t.Fatal("block 50 wrong")
	}
	if mock.blocks[51][0] != 0x22 {
		t.Fatal("block 51 wrong")
	}
}
func testReadWriteRoundtrip(t *testing.T) {
dev := newMockDevice(100 * 4096)
h := NewSCSIHandler(dev)
// Write
dataOut := make([]byte, 4096)
for i := range dataOut {
dataOut[i] = byte(i % 256)
}
var wcdb [16]byte
wcdb[0] = ScsiWrite10
binary.BigEndian.PutUint32(wcdb[2:6], 0)
binary.BigEndian.PutUint16(wcdb[7:9], 1)
h.HandleCommand(wcdb, dataOut)
// Read back
var rcdb [16]byte
rcdb[0] = ScsiRead10
binary.BigEndian.PutUint32(rcdb[2:6], 0)
binary.BigEndian.PutUint16(rcdb[7:9], 1)
r := h.HandleCommand(rcdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("read failed")
}
for i := 0; i < 4096; i++ {
if r.Data[i] != byte(i%256) {
t.Fatalf("byte %d: got %d, want %d", i, r.Data[i], i%256)
}
}
}
func testWriteOOB(t *testing.T) {
dev := newMockDevice(10 * 4096) // 10 blocks
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiWrite10
binary.BigEndian.PutUint32(cdb[2:6], 9)
binary.BigEndian.PutUint16(cdb[7:9], 2) // LBA 9 + 2 blocks > 10
r := h.HandleCommand(cdb, make([]byte, 8192))
if r.Status != SCSIStatusCheckCond {
t.Fatal("should fail for OOB")
}
}
func testReadOOB(t *testing.T) {
dev := newMockDevice(10 * 4096)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiRead10
binary.BigEndian.PutUint32(cdb[2:6], 10) // LBA 10 == total blocks
binary.BigEndian.PutUint16(cdb[7:9], 1)
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusCheckCond {
t.Fatal("should fail for OOB")
}
}
func testZeroLengthTransfer(t *testing.T) {
dev := newMockDevice(100 * 4096)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiRead10
binary.BigEndian.PutUint32(cdb[2:6], 0)
binary.BigEndian.PutUint16(cdb[7:9], 0) // 0 blocks
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("zero-length read should succeed")
}
}
func testSyncCache(t *testing.T) {
dev := newMockDevice(100 * 4096)
h := NewSCSIHandler(dev)
var cdb [16]byte
cdb[0] = ScsiSyncCache10
r := h.HandleCommand(cdb, nil)
if r.Status != SCSIStatusGood {
t.Fatal("sync cache failed")
}
}
func testSyncCacheError(t *testing.T) {
	// A flush error surfaces as CHECK CONDITION.
	mock := newMockDevice(100 * 4096)
	mock.syncErr = errors.New("disk error")
	handler := NewSCSIHandler(mock)
	var cdb [16]byte
	cdb[0] = ScsiSyncCache10
	res := handler.HandleCommand(cdb, nil)
	if res.Status != SCSIStatusCheckCond {
		t.Fatal("should fail")
	}
}
func testUnmapSingle(t *testing.T) {
	// One UNMAP descriptor removes the targeted block from the mock.
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	mock.blocks[5] = make([]byte, 4096)
	mock.blocks[5][0] = 0xFF
	// Parameter list: 8-byte header + one 16-byte descriptor.
	params := make([]byte, 24)
	binary.BigEndian.PutUint16(params[0:2], 22) // data length
	binary.BigEndian.PutUint16(params[2:4], 16) // block desc length
	binary.BigEndian.PutUint64(params[8:16], 5) // LBA
	binary.BigEndian.PutUint32(params[16:20], 1) // num blocks
	var cdb [16]byte
	cdb[0] = ScsiUnmap
	res := handler.HandleCommand(cdb, params)
	if res.Status != SCSIStatusGood {
		t.Fatal("unmap failed")
	}
	if _, ok := mock.blocks[5]; ok {
		t.Fatal("block 5 should be trimmed")
	}
}
func testUnmapMultipleDescriptors(t *testing.T) {
	// Two descriptors in one UNMAP must both be honored.
	mock := newMockDevice(100 * 4096)
	handler := NewSCSIHandler(mock)
	mock.blocks[3] = make([]byte, 4096)
	mock.blocks[7] = make([]byte, 4096)
	params := make([]byte, 40) // 8 header + 2*16 descriptors
	binary.BigEndian.PutUint16(params[0:2], 38)
	binary.BigEndian.PutUint16(params[2:4], 32)
	// Descriptor 1: LBA=3, count=1
	binary.BigEndian.PutUint64(params[8:16], 3)
	binary.BigEndian.PutUint32(params[16:20], 1)
	// Descriptor 2: LBA=7, count=1
	binary.BigEndian.PutUint64(params[24:32], 7)
	binary.BigEndian.PutUint32(params[32:36], 1)
	var cdb [16]byte
	cdb[0] = ScsiUnmap
	res := handler.HandleCommand(cdb, params)
	if res.Status != SCSIStatusGood {
		t.Fatal("unmap failed")
	}
	if _, ok := mock.blocks[3]; ok {
		t.Fatal("block 3 should be trimmed")
	}
	if _, ok := mock.blocks[7]; ok {
		t.Fatal("block 7 should be trimmed")
	}
}
// testUnmapShortParam verifies a truncated UNMAP parameter list is rejected
// with CHECK CONDITION instead of being parsed past its end.
func testUnmapShortParam(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiUnmap
	r := h.HandleCommand(cdb, []byte{1, 2, 3}) // too short
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail for short unmap params")
	}
}
// testBuildSenseData verifies the fixed-format sense buffer layout:
// 18 bytes, response code 0x70, sense key at [2], ASC at [12].
func testBuildSenseData(t *testing.T) {
	data := BuildSenseData(SenseIllegalRequest, ASCInvalidOpcode, ASCQLuk)
	if len(data) != 18 {
		t.Fatalf("length: %d", len(data))
	}
	if data[0] != 0x70 {
		t.Fatal("response code wrong")
	}
	if data[2] != SenseIllegalRequest {
		t.Fatal("sense key wrong")
	}
	if data[12] != ASCInvalidOpcode {
		t.Fatal("ASC wrong")
	}
}
// testReadError verifies a device read failure maps to CHECK CONDITION
// with sense key MEDIUM ERROR.
func testReadError(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	dev.readErr = errors.New("io error")
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 0)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail")
	}
	if r.SenseKey != SenseMediumError {
		t.Fatal("should be MEDIUM_ERROR")
	}
}
// testWriteError verifies a device write failure maps to CHECK CONDITION.
func testWriteError(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	dev.writeErr = errors.New("io error")
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 0)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	r := h.HandleCommand(cdb, make([]byte, 4096))
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail")
	}
}
// testReadCapacity16InvalidSA verifies READ CAPACITY (16) rejects a CDB
// whose service-action field is not the expected value.
func testReadCapacity16InvalidSA(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiReadCapacity16
	cdb[1] = 0x05 // wrong service action
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail for wrong SA")
	}
}

421
weed/storage/blockvol/iscsi/session.go

@ -0,0 +1,421 @@
package iscsi
import (
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
)
// Sentinel errors surfaced by session operations.
var (
	ErrSessionClosed    = errors.New("iscsi: session closed")
	ErrCmdSNOutOfWindow = errors.New("iscsi: CmdSN out of window")
)

// SessionState tracks the lifecycle of an iSCSI session.
type SessionState int

const (
	SessionLogin    SessionState = iota // login phase
	SessionLoggedIn                     // full feature phase
	SessionLogout                       // logout requested
	SessionClosed                       // terminated
)
// Session manages a single iSCSI session (one initiator connection).
// PDU writes and statSN updates are serialized under mu; expCmdSN and
// maxCmdSN are atomics so they can be read without holding the lock.
type Session struct {
	mu       sync.Mutex
	state    SessionState
	conn     net.Conn
	scsi     *SCSIHandler // bound at login; nil until a device is resolved
	config   TargetConfig
	resolver TargetResolver
	devices  DeviceLookup
	// Sequence numbers
	expCmdSN atomic.Uint32 // expected CmdSN from initiator
	maxCmdSN atomic.Uint32 // max CmdSN we allow
	statSN   uint32        // target status sequence number
	// Login state
	negotiator *LoginNegotiator
	loginDone  bool
	// Data sequencing
	dataInWriter *DataInWriter // created after login from negotiated MaxRecvDataSegLen
	// Shutdown
	closed   atomic.Bool
	closeErr error
	// Logging
	logger *log.Logger
}
// NewSession creates a new iSCSI session on the given connection.
// A nil logger falls back to log.Default(). The command window starts at
// expCmdSN=1 with room for 32 outstanding commands.
func NewSession(conn net.Conn, config TargetConfig, resolver TargetResolver, devices DeviceLookup, logger *log.Logger) *Session {
	if logger == nil {
		logger = log.Default()
	}
	s := &Session{
		state:      SessionLogin,
		conn:       conn,
		config:     config,
		resolver:   resolver,
		devices:    devices,
		negotiator: NewLoginNegotiator(config),
		logger:     logger,
	}
	s.expCmdSN.Store(1)
	s.maxCmdSN.Store(32) // window of 32 commands
	return s
}
// HandleConnection processes PDUs until the connection is closed or an
// error occurs. Errors observed after Close() has been requested, as well
// as EOF / closed-connection errors, are treated as a clean shutdown.
func (s *Session) HandleConnection() error {
	defer s.close()
	for !s.closed.Load() {
		pdu, err := ReadPDU(s.conn)
		if err != nil {
			if s.closed.Load() {
				return nil
			}
			// Initiator hung up normally.
			if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
				return nil
			}
			return fmt.Errorf("read PDU: %w", err)
		}
		if err := s.dispatch(pdu); err != nil {
			if s.closed.Load() {
				return nil
			}
			return fmt.Errorf("dispatch %s: %w", OpcodeName(pdu.Opcode()), err)
		}
	}
	return nil
}
// Close terminates the session. It sets the closed flag first so that the
// read loop treats the resulting connection error as a clean shutdown.
func (s *Session) Close() error {
	s.closed.Store(true)
	return s.conn.Close()
}
// close marks the session terminated and tears down the connection.
// Called from HandleConnection's defer; safe to call more than once.
func (s *Session) close() {
	s.closed.Store(true)
	s.mu.Lock()
	s.state = SessionClosed
	s.mu.Unlock()
	s.conn.Close()
}
// dispatch routes one PDU to its opcode-specific handler. Unknown opcodes
// are answered with a Reject PDU rather than tearing down the connection.
func (s *Session) dispatch(pdu *PDU) error {
	op := pdu.Opcode()
	switch op {
	case OpLoginReq:
		return s.handleLogin(pdu)
	case OpTextReq:
		return s.handleText(pdu)
	case OpSCSICmd:
		return s.handleSCSICmd(pdu)
	case OpSCSIDataOut:
		// Handled inline during write command processing
		return nil
	case OpNOPOut:
		return s.handleNOPOut(pdu)
	case OpLogoutReq:
		return s.handleLogout(pdu)
	case OpSCSITaskMgmt:
		return s.handleTaskMgmt(pdu)
	default:
		s.logger.Printf("unhandled opcode: %s", OpcodeName(op))
		return s.sendReject(pdu, 0x04) // command not supported
	}
}
// handleLogin feeds one Login Request into the negotiator and writes the
// response. When negotiation finishes it transitions to full-feature
// phase, sizes the Data-In writer from the negotiated MaxRecvDataSegLen,
// and binds the SCSI handler to the logged-into target's device.
func (s *Session) handleLogin(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	resp := s.negotiator.HandleLoginPDU(pdu, s.resolver)
	// Set sequence numbers
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())
	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}
	if s.negotiator.Done() {
		s.loginDone = true
		s.state = SessionLoggedIn
		result := s.negotiator.Result()
		s.dataInWriter = NewDataInWriter(uint32(result.MaxRecvDataSegLen))
		// Bind SCSI handler to the device for the target the initiator logged into.
		// NOTE: s.scsi stays nil for discovery sessions (empty TargetName)
		// or when no DeviceLookup was supplied.
		if s.devices != nil && result.TargetName != "" {
			dev := s.devices.LookupDevice(result.TargetName)
			s.scsi = NewSCSIHandler(dev)
		}
		s.logger.Printf("login complete: initiator=%s target=%s session=%s",
			result.InitiatorName, result.TargetName, result.SessionType)
	}
	return nil
}
// handleText serves a Text Request (SendTargets discovery), advancing
// StatSN and echoing the current command window in the response.
func (s *Session) handleText(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Discovery targets are only available when the resolver can list them.
	var known []DiscoveryTarget
	if lister, ok := s.resolver.(TargetLister); ok {
		known = lister.ListTargets()
	}
	reply := HandleTextRequest(pdu, known)
	reply.SetStatSN(s.statSN)
	s.statSN++
	reply.SetExpCmdSN(s.expCmdSN.Load())
	reply.SetMaxCmdSN(s.maxCmdSN.Load())
	return WritePDU(s.conn, reply)
}
// handleSCSICmd executes a SCSI Command PDU: collects any write data
// (immediate plus R2T-solicited), runs the command through the bound SCSI
// handler, and replies with Data-In PDUs (successful reads) or a SCSI
// Response (everything else).
func (s *Session) handleSCSICmd(pdu *PDU) error {
	if !s.loginDone {
		return s.sendReject(pdu, 0x0b) // protocol error
	}
	// Login can complete without binding a device (no DeviceLookup supplied,
	// or a discovery session with an empty TargetName). Reject instead of
	// dereferencing a nil handler below.
	if s.scsi == nil {
		return s.sendReject(pdu, 0x0b) // protocol error
	}
	cdb := pdu.CDB()
	itt := pdu.InitiatorTaskTag()
	flags := pdu.OpSpecific1()
	// Advance CmdSN; immediate commands do not consume a sequence number.
	if !pdu.Immediate() {
		s.advanceCmdSN()
	}
	isWrite := flags&FlagW != 0
	isRead := flags&FlagR != 0
	expectedLen := pdu.ExpectedDataTransferLength()
	// Write path: gather immediate data, then solicit the rest via R2T.
	var dataOut []byte
	if isWrite && expectedLen > 0 {
		collector := NewDataOutCollector(expectedLen)
		// Immediate data
		if len(pdu.DataSegment) > 0 {
			if err := collector.AddImmediateData(pdu.DataSegment); err != nil {
				return s.sendCheckCondition(itt, SenseIllegalRequest, ASCInvalidFieldInCDB, ASCQLuk)
			}
		}
		// If more data needed, send R2T and collect Data-Out PDUs
		if !collector.Done() {
			if err := s.collectDataOut(collector, itt); err != nil {
				return err
			}
		}
		dataOut = collector.Data()
	}
	// Execute SCSI command
	result := s.scsi.HandleCommand(cdb, dataOut)
	s.mu.Lock()
	expCmdSN := s.expCmdSN.Load()
	maxCmdSN := s.maxCmdSN.Load()
	s.mu.Unlock()
	if isRead && result.Status == SCSIStatusGood && len(result.Data) > 0 {
		// Successful read: stream the payload as Data-In PDUs; the final
		// Data-In carries status, so no separate SCSI Response follows.
		s.mu.Lock()
		_, err := s.dataInWriter.WriteDataIn(s.conn, result.Data, itt, expCmdSN, maxCmdSN, &s.statSN)
		s.mu.Unlock()
		return err
	}
	// Send SCSI Response
	s.mu.Lock()
	err := SendSCSIResponse(s.conn, result, itt, &s.statSN, expCmdSN, maxCmdSN)
	s.mu.Unlock()
	return err
}
// collectDataOut runs the R2T loop for a write: repeatedly sends an R2T
// for the remaining byte range, then reads Data-Out PDUs into collector
// until one carries the F (final) bit. Returns once the collector has the
// full expected transfer, or on the first protocol/transport error.
func (s *Session) collectDataOut(collector *DataOutCollector, itt uint32) error {
	var r2tSN uint32
	ttt := itt // use ITT as TTT for simplicity
	for !collector.Done() {
		// Send R2T
		s.mu.Lock()
		r2t := BuildR2T(itt, ttt, r2tSN, s.totalReceived(collector), collector.Remaining(),
			s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load())
		s.mu.Unlock()
		if err := WritePDU(s.conn, r2t); err != nil {
			return err
		}
		r2tSN++
		// Read Data-Out PDUs until F-bit
		for {
			doPDU, err := ReadPDU(s.conn)
			if err != nil {
				return err
			}
			if doPDU.Opcode() != OpSCSIDataOut {
				return fmt.Errorf("expected Data-Out, got %s", OpcodeName(doPDU.Opcode()))
			}
			if err := collector.AddDataOut(doPDU); err != nil {
				return err
			}
			if doPDU.OpSpecific1()&FlagF != 0 {
				break
			}
		}
	}
	return nil
}
// totalReceived reports how many bytes of the expected transfer the
// collector has already gathered (used as the R2T buffer offset).
func (s *Session) totalReceived(c *DataOutCollector) uint32 {
	pending := c.Remaining()
	return c.expectedLen - pending
}
// handleNOPOut answers a NOP-Out ping with a NOP-In, echoing any payload
// the initiator attached.
func (s *Session) handleNOPOut(pdu *PDU) error {
	reply := &PDU{}
	reply.SetOpcode(OpNOPIn)
	reply.SetOpSpecific1(FlagF)
	reply.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	reply.SetTargetTransferTag(0xFFFFFFFF) // reserved TTT: not a target-initiated ping
	s.mu.Lock()
	reply.SetStatSN(s.statSN)
	s.statSN++
	reply.SetExpCmdSN(s.expCmdSN.Load())
	reply.SetMaxCmdSN(s.maxCmdSN.Load())
	s.mu.Unlock()
	// Echo back data if present
	if payload := pdu.DataSegment; len(payload) > 0 {
		reply.DataSegment = payload
	}
	return WritePDU(s.conn, reply)
}
// handleLogout answers a Logout Request with a successful Logout Response,
// then closes the connection, which also signals HandleConnection to exit
// cleanly.
func (s *Session) handleLogout(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.state = SessionLogout
	resp := &PDU{}
	resp.SetOpcode(OpLogoutResp)
	resp.SetOpSpecific1(FlagF)
	resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	resp.BHS[2] = 0x00 // response: connection/session closed successfully
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())
	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}
	// Signal HandleConnection to exit
	s.closed.Store(true)
	s.conn.Close()
	return nil
}
// handleTaskMgmt acknowledges a Task Management request. This target does
// not track outstanding tasks, so every function is reported complete.
func (s *Session) handleTaskMgmt(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	reply := &PDU{}
	reply.SetOpcode(OpSCSITaskResp)
	reply.SetOpSpecific1(FlagF)
	reply.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	reply.BHS[2] = 0x00 // function complete
	reply.SetStatSN(s.statSN)
	s.statSN++
	reply.SetExpCmdSN(s.expCmdSN.Load())
	reply.SetMaxCmdSN(s.maxCmdSN.Load())
	return WritePDU(s.conn, reply)
}
// advanceCmdSN consumes one command sequence number, sliding the whole
// command window forward by one (window size stays constant).
func (s *Session) advanceCmdSN() {
	s.expCmdSN.Add(1)
	s.maxCmdSN.Add(1)
}
// sendReject sends a Reject PDU with the given reason code, carrying the
// offending PDU's BHS in the data segment so the initiator can correlate it.
func (s *Session) sendReject(origPDU *PDU, reason uint8) error {
	resp := &PDU{}
	resp.SetOpcode(OpReject)
	resp.SetOpSpecific1(FlagF)
	resp.BHS[2] = reason
	resp.SetInitiatorTaskTag(0xFFFFFFFF)
	s.mu.Lock()
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())
	s.mu.Unlock()
	// Include the rejected BHS in the data segment
	resp.DataSegment = origPDU.BHS[:]
	return WritePDU(s.conn, resp)
}
// sendCheckCondition sends a CHECK CONDITION SCSI Response carrying the
// given sense key / ASC / ASCQ.
func (s *Session) sendCheckCondition(itt uint32, senseKey, asc, ascq uint8) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	res := SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  senseKey,
		SenseASC:  asc,
		SenseASCQ: ascq,
	}
	return SendSCSIResponse(s.conn, res, itt, &s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load())
}
// TargetLister is an optional interface that TargetResolver can implement
// to support SendTargets discovery. Sessions type-assert for it when
// answering Text Requests.
type TargetLister interface {
	ListTargets() []DiscoveryTarget
}

// DeviceLookup resolves a target IQN to a BlockDevice.
// Used after login to bind the session to the correct volume.
type DeviceLookup interface {
	LookupDevice(iqn string) BlockDevice
}
// State returns the current session state. Safe for concurrent use.
func (s *Session) State() SessionState {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.state
}

364
weed/storage/blockvol/iscsi/session_test.go

@ -0,0 +1,364 @@
package iscsi
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
"testing"
"time"
)
// testResolver implements TargetResolver, TargetLister, and DeviceLookup.
// When dev is nil, LookupDevice falls back to a nullDevice.
type testResolver struct {
	targets []DiscoveryTarget
	dev     BlockDevice
}
// HasTarget reports whether name is among the registered discovery targets.
func (r *testResolver) HasTarget(name string) bool {
	for i := range r.targets {
		if r.targets[i].Name == name {
			return true
		}
	}
	return false
}
// ListTargets returns the fixed target list for SendTargets discovery.
func (r *testResolver) ListTargets() []DiscoveryTarget {
	return r.targets
}

// LookupDevice returns the configured device, or a nullDevice when none
// was supplied.
func (r *testResolver) LookupDevice(iqn string) BlockDevice {
	if r.dev != nil {
		return r.dev
	}
	return &nullDevice{}
}
// newTestResolver builds a resolver with the default test target and no device.
func newTestResolver() *testResolver {
	return newTestResolverWithDevice(nil)
}

// newTestResolverWithDevice builds a resolver whose single target is backed
// by the given device (nil means nullDevice).
func newTestResolverWithDevice(dev BlockDevice) *testResolver {
	return &testResolver{
		targets: []DiscoveryTarget{
			{Name: testTargetName, Address: "127.0.0.1:3260,1"},
		},
		dev: dev,
	}
}
// TestSession runs the session-level protocol subtests.
func TestSession(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"login_and_read", testLoginAndRead},
		{"login_and_write", testLoginAndWrite},
		{"nop_ping", testNOPPing},
		{"logout", testLogout},
		{"discovery_session", testDiscoverySession},
		{"task_mgmt", testTaskMgmt},
		{"reject_scsi_before_login", testRejectSCSIBeforeLogin},
		{"connection_close", testConnectionClose},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
// sessionTestEnv bundles the client side of a net.Pipe, the session under
// test, and a channel that receives HandleConnection's exit error.
type sessionTestEnv struct {
	clientConn net.Conn
	session    *Session
	done       chan error
}
// setupSession wires a Session to an in-memory net.Pipe, runs
// HandleConnection in the background, and registers cleanup for both ends.
func setupSession(t *testing.T) *sessionTestEnv {
	t.Helper()
	client, server := net.Pipe()
	dev := newMockDevice(1024 * 4096) // 1024 blocks
	config := DefaultTargetConfig()
	config.TargetName = testTargetName
	resolver := newTestResolverWithDevice(dev)
	logger := log.New(io.Discard, "", 0)
	sess := NewSession(server, config, resolver, resolver, logger)
	done := make(chan error, 1)
	go func() {
		done <- sess.HandleConnection()
	}()
	t.Cleanup(func() {
		sess.Close()
		client.Close()
	})
	return &sessionTestEnv{clientConn: client, session: sess, done: done}
}
// doLogin performs a single-PDU Normal-session login over conn and fails
// the test unless the target answers with a successful Login Response.
func doLogin(t *testing.T, conn net.Conn) {
	t.Helper()
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)
	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLoginResp {
		t.Fatalf("expected LoginResp, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
}
// testLoginAndRead logs in, issues READ(10) for one block, and expects a
// single Data-In PDU with the S-bit set and a 4096-byte payload.
func testLoginAndRead(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	// Send SCSI READ_10 for 1 block at LBA 0
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR) // Final + Read
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)
	cmd.SetExpStatSN(2)
	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 0) // LBA=0
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	cmd.SetCDB(cdb)
	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}
	// Expect Data-In with S-bit (single PDU, 4096 bytes of zeros)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.OpSpecific1()&FlagS == 0 {
		t.Fatal("S-bit not set")
	}
	if len(resp.DataSegment) != 4096 {
		t.Fatalf("data length: %d", len(resp.DataSegment))
	}
}
// testLoginAndWrite logs in and issues WRITE(10) with the full payload as
// immediate data, expecting a GOOD SCSI Response (no R2T round-trip).
func testLoginAndWrite(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	// SCSI WRITE_10: 1 block at LBA 5 with immediate data
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagW) // Final + Write
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)
	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 5)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	cmd.SetCDB(cdb)
	// Include immediate data
	cmd.DataSegment = bytes.Repeat([]byte{0xAA}, 4096)
	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}
	// Expect SCSI Response (good)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIResp {
		t.Fatalf("expected SCSI Response, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.SCSIStatus() != SCSIStatusGood {
		t.Fatalf("status: %d", resp.SCSIStatus())
	}
}
// testNOPPing sends a NOP-Out with a payload and verifies the NOP-In echo
// preserves both the ITT and the data.
func testNOPPing(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	// Send NOP-Out
	nop := &PDU{}
	nop.SetOpcode(OpNOPOut)
	nop.SetOpSpecific1(FlagF)
	nop.SetInitiatorTaskTag(0x9999)
	nop.SetImmediate(true)
	nop.DataSegment = []byte("ping")
	if err := WritePDU(env.clientConn, nop); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpNOPIn {
		t.Fatalf("expected NOP-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.InitiatorTaskTag() != 0x9999 {
		t.Fatal("ITT mismatch")
	}
	if string(resp.DataSegment) != "ping" {
		t.Fatalf("echo data: %q", resp.DataSegment)
	}
}
// testLogout verifies a Logout Request gets a Logout Response and that the
// target then closes the connection (subsequent read fails).
func testLogout(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	// Send Logout
	logout := &PDU{}
	logout.SetOpcode(OpLogoutReq)
	logout.SetOpSpecific1(FlagF)
	logout.SetInitiatorTaskTag(0xAAAA)
	logout.SetCmdSN(2)
	if err := WritePDU(env.clientConn, logout); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLogoutResp {
		t.Fatalf("expected Logout Response, got %s", OpcodeName(resp.Opcode()))
	}
	// After logout, the connection should be closed by the target.
	// Verify by trying to read — should get EOF.
	_, err = ReadPDU(env.clientConn)
	if err == nil {
		t.Fatal("expected EOF after logout")
	}
}
// testDiscoverySession logs in as a Discovery session and verifies that
// SendTargets=All yields a non-empty Text Response. Login-phase I/O errors
// are now checked instead of discarded, so a failure there is reported at
// its source rather than as a confusing later error.
func testDiscoverySession(t *testing.T) {
	env := setupSession(t)
	// Login as discovery session
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("SessionType", "Discovery")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)
	if err := WritePDU(env.clientConn, req); err != nil {
		t.Fatal(err)
	}
	if _, err := ReadPDU(env.clientConn); err != nil { // login resp
		t.Fatal(err)
	}
	// Send SendTargets=All
	textParams := NewParams()
	textParams.Set("SendTargets", "All")
	textReq := makeTextReq(textParams)
	textReq.SetCmdSN(2)
	if err := WritePDU(env.clientConn, textReq); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpTextResp {
		t.Fatalf("expected Text Response, got %s", OpcodeName(resp.Opcode()))
	}
	body := string(resp.DataSegment)
	if len(body) == 0 {
		t.Fatal("empty discovery response")
	}
}
// testTaskMgmt sends an ABORT TASK request and expects a Task Management
// Response (this target always reports "function complete").
func testTaskMgmt(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	tm := &PDU{}
	tm.SetOpcode(OpSCSITaskMgmt)
	tm.SetOpSpecific1(FlagF | 0x01) // ABORT TASK
	tm.SetInitiatorTaskTag(0xBBBB)
	tm.SetImmediate(true)
	if err := WritePDU(env.clientConn, tm); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSITaskResp {
		t.Fatalf("expected Task Mgmt Response, got %s", OpcodeName(resp.Opcode()))
	}
}
// testRejectSCSIBeforeLogin verifies a SCSI command sent before login
// completes is answered with a Reject PDU, not executed.
func testRejectSCSIBeforeLogin(t *testing.T) {
	env := setupSession(t)
	// Send SCSI command without login
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR)
	cmd.SetInitiatorTaskTag(1)
	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpReject {
		t.Fatalf("expected Reject, got %s", OpcodeName(resp.Opcode()))
	}
}
// testConnectionClose verifies HandleConnection returns nil (clean exit)
// shortly after the initiator side closes the connection.
func testConnectionClose(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)
	// Close client side
	env.clientConn.Close()
	select {
	case err := <-env.done:
		if err != nil {
			t.Fatalf("unexpected error on clean close: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("session did not detect close")
	}
}

216
weed/storage/blockvol/iscsi/target.go

@ -0,0 +1,216 @@
package iscsi
import (
"errors"
"fmt"
"log"
"net"
"sync"
"sync/atomic"
)
// Sentinel errors for target-server operations.
var (
	ErrTargetClosed   = errors.New("iscsi: target server closed")
	ErrVolumeNotFound = errors.New("iscsi: volume not found")
)
// TargetServer manages the iSCSI target: TCP listener, volume registry,
// and active sessions. mu guards listener and volumes; activeMu guards
// the active-session map used for graceful shutdown.
type TargetServer struct {
	mu       sync.RWMutex
	listener net.Listener
	config   TargetConfig
	volumes  map[string]BlockDevice // target IQN -> device
	addr     string
	// Active session tracking for graceful shutdown
	activeMu sync.Mutex
	active   map[uint64]*Session
	nextID   atomic.Uint64
	sessions sync.WaitGroup
	closed   chan struct{}
	logger   *log.Logger
}
// NewTargetServer creates a target server bound to the given address.
// A nil logger falls back to log.Default(); the listener is not opened
// until ListenAndServe or Serve is called.
func NewTargetServer(addr string, config TargetConfig, logger *log.Logger) *TargetServer {
	if logger == nil {
		logger = log.Default()
	}
	return &TargetServer{
		config:  config,
		volumes: make(map[string]BlockDevice),
		active:  make(map[uint64]*Session),
		addr:    addr,
		closed:  make(chan struct{}),
		logger:  logger,
	}
}
// AddVolume registers a block device under the given target IQN,
// replacing any existing registration for that IQN.
func (ts *TargetServer) AddVolume(iqn string, dev BlockDevice) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	ts.volumes[iqn] = dev
	ts.logger.Printf("volume added: %s", iqn)
}

// RemoveVolume unregisters a target IQN. Removing an unknown IQN is a no-op.
func (ts *TargetServer) RemoveVolume(iqn string) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	delete(ts.volumes, iqn)
	ts.logger.Printf("volume removed: %s", iqn)
}
// HasTarget implements TargetResolver: reports whether the IQN is registered.
func (ts *TargetServer) HasTarget(name string) bool {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	_, found := ts.volumes[name]
	return found
}
// ListTargets implements TargetLister for SendTargets discovery.
//
// The listen address is computed inline rather than via ListenAddr() to
// avoid acquiring ts.mu.RLock recursively while it is already read-held:
// per the sync.RWMutex documentation, a nested RLock can deadlock when a
// writer (AddVolume / RemoveVolume / ListenAndServe) is queued between
// the two read acquisitions.
func (ts *TargetServer) ListTargets() []DiscoveryTarget {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	addr := ts.addr
	if ts.listener != nil {
		addr = ts.listener.Addr().String()
	}
	targets := make([]DiscoveryTarget, 0, len(ts.volumes))
	for iqn := range ts.volumes {
		targets = append(targets, DiscoveryTarget{
			Name:    iqn,
			Address: addr,
		})
	}
	return targets
}
// ListenAndServe starts listening for iSCSI connections.
// Blocks until Close() is called or an error occurs.
func (ts *TargetServer) ListenAndServe() error {
	ln, err := net.Listen("tcp", ts.addr)
	if err != nil {
		return fmt.Errorf("iscsi target: listen: %w", err)
	}
	ts.mu.Lock()
	ts.listener = ln
	ts.mu.Unlock()
	ts.logger.Printf("iSCSI target listening on %s", ln.Addr())
	return ts.acceptLoop(ln)
}

// Serve accepts connections on an existing listener (e.g. one created
// with port 0 by a test). Blocks like ListenAndServe.
func (ts *TargetServer) Serve(ln net.Listener) error {
	ts.mu.Lock()
	ts.listener = ln
	ts.mu.Unlock()
	return ts.acceptLoop(ln)
}
// acceptLoop accepts connections until the listener fails, spawning one
// session goroutine per connection. An accept error observed after Close()
// (which closes the listener) is treated as a normal shutdown.
func (ts *TargetServer) acceptLoop(ln net.Listener) error {
	for {
		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-ts.closed:
				return nil
			default:
				return fmt.Errorf("iscsi target: accept: %w", err)
			}
		}
		ts.sessions.Add(1)
		go ts.handleConn(conn)
	}
}
// handleConn runs one session to completion, registering it in the active
// map so Close() can terminate it, and deregistering on exit.
func (ts *TargetServer) handleConn(conn net.Conn) {
	defer ts.sessions.Done()
	defer conn.Close()
	ts.logger.Printf("new connection from %s", conn.RemoteAddr())
	sess := NewSession(conn, ts.config, ts, ts, ts.logger)
	id := ts.nextID.Add(1)
	ts.activeMu.Lock()
	ts.active[id] = sess
	ts.activeMu.Unlock()
	defer func() {
		ts.activeMu.Lock()
		delete(ts.active, id)
		ts.activeMu.Unlock()
	}()
	if err := sess.HandleConnection(); err != nil {
		ts.logger.Printf("session error (%s): %v", conn.RemoteAddr(), err)
	}
	ts.logger.Printf("connection closed: %s", conn.RemoteAddr())
}
// LookupDevice implements DeviceLookup. Returns the BlockDevice for the
// given IQN, or a nullDevice if not found.
func (ts *TargetServer) LookupDevice(iqn string) BlockDevice {
	ts.mu.RLock()
	dev, ok := ts.volumes[iqn]
	ts.mu.RUnlock()
	if !ok {
		return &nullDevice{}
	}
	return dev
}
// Close gracefully shuts down the target server: stops accepting, closes
// every active session (unblocking their ReadPDU calls), and waits for all
// session goroutines to drain. Idempotent — later calls return nil.
func (ts *TargetServer) Close() error {
	select {
	case <-ts.closed:
		return nil
	default:
	}
	close(ts.closed)
	ts.mu.RLock()
	ln := ts.listener
	ts.mu.RUnlock()
	if ln != nil {
		ln.Close()
	}
	// Close all active sessions to unblock ReadPDU
	ts.activeMu.Lock()
	for _, sess := range ts.active {
		sess.Close()
	}
	ts.activeMu.Unlock()
	ts.sessions.Wait()
	return nil
}
// ListenAddr returns the actual listen address (useful when port=0).
// Before Serve/ListenAndServe installs a listener it returns the
// configured address string.
func (ts *TargetServer) ListenAddr() string {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	if ts.listener != nil {
		return ts.listener.Addr().String()
	}
	return ts.addr
}
// nullDevice is a stub BlockDevice used when no volumes are registered.
// Reads return zero-filled buffers, writes/trim/sync are no-ops, and
// IsHealthy reports false so callers can detect the missing backing volume.
type nullDevice struct{}

func (d *nullDevice) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return make([]byte, length), nil
}
func (d *nullDevice) WriteAt(lba uint64, data []byte) error { return nil }
func (d *nullDevice) Trim(lba uint64, length uint32) error  { return nil }
func (d *nullDevice) SyncCache() error                      { return nil }
func (d *nullDevice) BlockSize() uint32                     { return 4096 }
func (d *nullDevice) VolumeSize() uint64                    { return 0 }
func (d *nullDevice) IsHealthy() bool                       { return false }

259
weed/storage/blockvol/iscsi/target_test.go

@ -0,0 +1,259 @@
package iscsi
import (
"encoding/binary"
"io"
"log"
"net"
"strings"
"testing"
"time"
)
// TestTarget runs the end-to-end target-server subtests over real TCP.
func TestTarget(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"listen_and_connect", testListenAndConnect},
		{"discovery_via_target", testDiscoveryViaTarget},
		{"login_read_write", testTargetLoginReadWrite},
		{"graceful_shutdown", testGracefulShutdown},
		{"add_remove_volume", testAddRemoveVolume},
		{"multiple_connections", testMultipleConnections},
		{"connect_no_volumes", testConnectNoVolumes},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
// setupTarget starts a TargetServer with one mock-backed volume on an
// ephemeral loopback port and returns the server plus its dial address.
func setupTarget(t *testing.T) (*TargetServer, string) {
	t.Helper()
	config := DefaultTargetConfig()
	config.TargetName = testTargetName
	logger := log.New(io.Discard, "", 0)
	ts := NewTargetServer("127.0.0.1:0", config, logger)
	dev := newMockDevice(256 * 4096) // 256 blocks = 1MB
	ts.AddVolume(testTargetName, dev)
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	go ts.Serve(ln)
	addr := ln.Addr().String()
	t.Cleanup(func() {
		ts.Close()
	})
	return ts, addr
}
// dialTarget opens a TCP connection to the target with a 2s timeout and
// registers its cleanup.
func dialTarget(t *testing.T, addr string) net.Conn {
	t.Helper()
	conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { conn.Close() })
	return conn
}
// loginToTarget performs a single-PDU Normal-session login over conn and
// fails the test unless the target reports login success.
func loginToTarget(t *testing.T, conn net.Conn) {
	t.Helper()
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)
	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
}
// testListenAndConnect verifies a plain TCP connect + login round-trip.
func testListenAndConnect(t *testing.T) {
	_, addr := setupTarget(t)
	conn := dialTarget(t, addr)
	loginToTarget(t, conn)
}
// testDiscoveryViaTarget runs SendTargets discovery over real TCP and
// verifies the response mentions the registered target. All PDU I/O errors
// are now checked — the original discarded them, which could later panic
// on a nil response PDU and mask the real failure.
func testDiscoveryViaTarget(t *testing.T) {
	_, addr := setupTarget(t)
	conn := dialTarget(t, addr)
	// Login as discovery
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("SessionType", "Discovery")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)
	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}
	if _, err := ReadPDU(conn); err != nil {
		t.Fatal(err)
	}
	// SendTargets=All
	textParams := NewParams()
	textParams.Set("SendTargets", "All")
	textReq := makeTextReq(textParams)
	textReq.SetCmdSN(2)
	if err := WritePDU(conn, textReq); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	body := string(resp.DataSegment)
	if !strings.Contains(body, testTargetName) {
		t.Fatalf("discovery response missing target: %q", body)
	}
}
// testTargetLoginReadWrite writes one block via the target and reads it
// back. Unlike the original, every PDU round-trip error is checked (a
// failed ReadPDU previously returned nil and the test panicked on the nil
// PDU), and the Data-In payload length is validated before indexing.
func testTargetLoginReadWrite(t *testing.T) {
	_, addr := setupTarget(t)
	conn := dialTarget(t, addr)
	loginToTarget(t, conn)
	// Write 1 block at LBA 0
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagW)
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)
	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 0)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	cmd.SetCDB(cdb)
	data := make([]byte, 4096)
	for i := range data {
		data[i] = 0xAB
	}
	cmd.DataSegment = data
	if err := WritePDU(conn, cmd); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.SCSIStatus() != SCSIStatusGood {
		t.Fatalf("write failed: %d", resp.SCSIStatus())
	}
	// Read it back
	cmd2 := &PDU{}
	cmd2.SetOpcode(OpSCSICmd)
	cmd2.SetOpSpecific1(FlagF | FlagR)
	cmd2.SetInitiatorTaskTag(2)
	cmd2.SetExpectedDataTransferLength(4096)
	cmd2.SetCmdSN(3)
	var cdb2 [16]byte
	cdb2[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb2[2:6], 0)
	binary.BigEndian.PutUint16(cdb2[7:9], 1)
	cmd2.SetCDB(cdb2)
	if err := WritePDU(conn, cmd2); err != nil {
		t.Fatal(err)
	}
	resp2, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp2.Opcode() != OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", OpcodeName(resp2.Opcode()))
	}
	// Guard before indexing: an empty segment should fail, not panic.
	if len(resp2.DataSegment) == 0 {
		t.Fatal("empty Data-In segment")
	}
	if resp2.DataSegment[0] != 0xAB {
		t.Fatal("data mismatch")
	}
}
// testGracefulShutdown verifies Close() drops established connections.
func testGracefulShutdown(t *testing.T) {
	ts, addr := setupTarget(t)
	conn := dialTarget(t, addr)
	loginToTarget(t, conn)
	// Close the target — should shut down cleanly
	ts.Close()
	// Connection should be dropped
	_, err := ReadPDU(conn)
	if err == nil {
		t.Fatal("expected error after shutdown")
	}
}
// testAddRemoveVolume exercises the volume registry: add two, remove one,
// and check HasTarget/ListTargets reflect the change.
func testAddRemoveVolume(t *testing.T) {
	config := DefaultTargetConfig()
	logger := log.New(io.Discard, "", 0)
	ts := NewTargetServer("127.0.0.1:0", config, logger)
	ts.AddVolume("iqn.2024.com.test:v1", newMockDevice(1024*4096))
	ts.AddVolume("iqn.2024.com.test:v2", newMockDevice(1024*4096))
	if !ts.HasTarget("iqn.2024.com.test:v1") {
		t.Fatal("v1 should exist")
	}
	ts.RemoveVolume("iqn.2024.com.test:v1")
	if ts.HasTarget("iqn.2024.com.test:v1") {
		t.Fatal("v1 should be removed")
	}
	if !ts.HasTarget("iqn.2024.com.test:v2") {
		t.Fatal("v2 should still exist")
	}
	targets := ts.ListTargets()
	if len(targets) != 1 {
		t.Fatalf("expected 1 target, got %d", len(targets))
	}
}
// testMultipleConnections verifies the target accepts several concurrent
// sessions; each connection is cleaned up by dialTarget's t.Cleanup.
func testMultipleConnections(t *testing.T) {
	_, addr := setupTarget(t)
	// Connect multiple clients
	for i := 0; i < 5; i++ {
		conn := dialTarget(t, addr)
		loginToTarget(t, conn)
	}
}
// testConnectNoVolumes verifies a Discovery login succeeds on a target
// with no registered volumes. PDU I/O errors are now checked — the
// original's `resp, _ := ReadPDU(conn)` would panic on a nil resp.
func testConnectNoVolumes(t *testing.T) {
	config := DefaultTargetConfig()
	logger := log.New(io.Discard, "", 0)
	ts := NewTargetServer("127.0.0.1:0", config, logger)
	// No volumes added
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	go ts.Serve(ln)
	t.Cleanup(func() { ts.Close() })
	conn := dialTarget(t, ln.Addr().String())
	// Login should work (discovery is fine without volumes)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("SessionType", "Discovery")
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)
	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}
	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatal("discovery login should succeed without volumes")
	}
}

38
weed/storage/blockvol/lba.go

@ -0,0 +1,38 @@
package blockvol
import (
"errors"
"fmt"
)
var (
ErrLBAOutOfBounds = errors.New("blockvol: LBA out of bounds")
ErrWritePastEnd = errors.New("blockvol: write extends past volume end")
ErrAlignment = errors.New("blockvol: data length not aligned to block size")
)
// ValidateLBA checks that lba is within the volume's logical address space.
func ValidateLBA(lba uint64, volumeSize uint64, blockSize uint32) error {
maxLBA := volumeSize / uint64(blockSize)
if lba >= maxLBA {
return fmt.Errorf("%w: lba=%d, max=%d", ErrLBAOutOfBounds, lba, maxLBA-1)
}
return nil
}
// ValidateWrite checks that a write at lba with dataLen bytes fits within
// the volume and that dataLen is aligned to blockSize.
func ValidateWrite(lba uint64, dataLen uint32, volumeSize uint64, blockSize uint32) error {
if err := ValidateLBA(lba, volumeSize, blockSize); err != nil {
return err
}
if dataLen%blockSize != 0 {
return fmt.Errorf("%w: dataLen=%d, blockSize=%d", ErrAlignment, dataLen, blockSize)
}
blocksNeeded := uint64(dataLen / blockSize)
maxLBA := volumeSize / uint64(blockSize)
if lba+blocksNeeded > maxLBA {
return fmt.Errorf("%w: lba=%d, blocks=%d, max=%d", ErrWritePastEnd, lba, blocksNeeded, maxLBA)
}
return nil
}

95
weed/storage/blockvol/lba_test.go

@ -0,0 +1,95 @@
package blockvol
import (
"errors"
"testing"
)
// TestLBAValidation runs the LBA/write bounds-checking subtests.
func TestLBAValidation(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "lba_within_bounds", run: testLBAWithinBounds},
		{name: "lba_last_block", run: testLBALastBlock},
		{name: "lba_out_of_bounds", run: testLBAOutOfBounds},
		{name: "lba_write_spans_end", run: testLBAWriteSpansEnd},
		{name: "lba_alignment", run: testLBAAlignment},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
// Shared geometry for the LBA validation tests.
const (
	testVolSize   = 100 * 1024 * 1024 * 1024 // 100GB
	testBlockSize = 4096
)
// testLBAWithinBounds checks in-range LBAs validate cleanly.
func testLBAWithinBounds(t *testing.T) {
	err := ValidateLBA(0, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("LBA=0 should be valid: %v", err)
	}
	err = ValidateLBA(1000, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("LBA=1000 should be valid: %v", err)
	}
}
// testLBALastBlock checks the final addressable LBA is still valid.
func testLBALastBlock(t *testing.T) {
	maxLBA := uint64(testVolSize/testBlockSize) - 1
	err := ValidateLBA(maxLBA, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("last LBA=%d should be valid: %v", maxLBA, err)
	}
}
// testLBAOutOfBounds verifies rejection of LBAs at and beyond the volume end.
func testLBAOutOfBounds(t *testing.T) {
	firstInvalid := uint64(testVolSize / testBlockSize)
	if err := ValidateLBA(firstInvalid, testVolSize, testBlockSize); !errors.Is(err, ErrLBAOutOfBounds) {
		t.Errorf("LBA=%d (one past end): expected ErrLBAOutOfBounds, got %v", firstInvalid, err)
	}
	if err := ValidateLBA(firstInvalid+1000, testVolSize, testBlockSize); !errors.Is(err, ErrLBAOutOfBounds) {
		t.Errorf("LBA far past end: expected ErrLBAOutOfBounds, got %v", err)
	}
}
// testLBAWriteSpansEnd verifies that a multi-block write may not run past
// the final block, while a single-block write at the last LBA is accepted.
func testLBAWriteSpansEnd(t *testing.T) {
	lastLBA := uint64(testVolSize/testBlockSize) - 1
	// Two blocks starting at the final LBA overrun the volume.
	if err := ValidateWrite(lastLBA, 2*testBlockSize, testVolSize, testBlockSize); !errors.Is(err, ErrWritePastEnd) {
		t.Errorf("write spanning end: expected ErrWritePastEnd, got %v", err)
	}
	// Exactly one block at the final LBA still fits.
	if err := ValidateWrite(lastLBA, testBlockSize, testVolSize, testBlockSize); err != nil {
		t.Errorf("write at last LBA should succeed: %v", err)
	}
}
// testLBAAlignment verifies that write lengths must be whole-block multiples.
func testLBAAlignment(t *testing.T) {
	if err := ValidateWrite(0, 4096, testVolSize, testBlockSize); err != nil {
		t.Errorf("aligned write should succeed: %v", err)
	}
	if err := ValidateWrite(0, 4000, testVolSize, testBlockSize); !errors.Is(err, ErrAlignment) {
		t.Errorf("unaligned write: expected ErrAlignment, got %v", err)
	}
	if err := ValidateWrite(0, 1, testVolSize, testBlockSize); !errors.Is(err, ErrAlignment) {
		t.Errorf("1-byte write: expected ErrAlignment, got %v", err)
	}
}

157
weed/storage/blockvol/recovery.go

@ -0,0 +1,157 @@
package blockvol
import (
"fmt"
"os"
)
// RecoveryResult contains the outcome of WAL recovery, as produced by
// RecoverWAL.
type RecoveryResult struct {
	EntriesReplayed int // number of WRITE/TRIM entries replayed into the dirty map
	HighestLSN uint64 // highest LSN seen during recovery (entries past the checkpoint only)
	TornEntries int // entries discarded: CRC/decode failure, or truncated at the scan boundary
}
// RecoverWAL scans the WAL region from tail to head, replaying valid entries
// into the dirty map. Entries with LSN <= checkpointLSN are skipped (already
// in extent). Scanning stops at the first CRC failure (torn write).
//
// The WAL is a circular buffer. If head >= tail, scan [tail, head).
// If head < tail (wrapped), scan [tail, walSize) then [0, head).
//
// The returned RecoveryResult reports how many entries were replayed, the
// highest LSN observed past the checkpoint, and how many entries were
// discarded as torn. A non-nil error is returned only for I/O failures
// while reading the WAL region; torn or corrupt entries end the scan but
// are not themselves errors.
func RecoverWAL(fd *os.File, sb *Superblock, dirtyMap *DirtyMap) (RecoveryResult, error) {
	result := RecoveryResult{}
	logicalHead := sb.WALHead
	logicalTail := sb.WALTail
	walOffset := sb.WALOffset
	walSize := sb.WALSize
	checkpointLSN := sb.WALCheckpointLSN
	if logicalHead == logicalTail {
		// WAL is empty (or fully flushed).
		return result, nil
	}
	// Convert logical (monotonically increasing) positions to physical
	// offsets within the ring.
	physHead := logicalHead % walSize
	physTail := logicalTail % walSize
	// Build the list of physical byte ranges to scan, in log order.
	type scanRange struct {
		start, end uint64 // physical positions within WAL
	}
	var ranges []scanRange
	if physHead > physTail {
		// No wrap: scan [tail, head).
		ranges = append(ranges, scanRange{physTail, physHead})
	} else {
		// Wrapped (physHead < physTail), or head and tail at the same
		// physical position with different logical positions — i.e. the WAL
		// is completely full. Both cases scan [tail, walSize) then [0, head).
		ranges = append(ranges, scanRange{physTail, walSize})
		if physHead > 0 {
			ranges = append(ranges, scanRange{0, physHead})
		}
	}
	for _, r := range ranges {
		pos := r.start
		for pos < r.end {
			remaining := r.end - pos
			// Need at least a full fixed header to parse anything.
			if remaining < uint64(walEntryHeaderSize) {
				break
			}
			// Read just the header first to learn the entry's size.
			headerBuf := make([]byte, walEntryHeaderSize)
			absOff := int64(walOffset + pos)
			if _, err := fd.ReadAt(headerBuf, absOff); err != nil {
				return result, fmt.Errorf("recovery: read header at WAL+%d: %w", pos, err)
			}
			// Type sits after LSN(8)+Epoch(8); Length follows Type/Flags/LBA.
			entryType := headerBuf[16]
			lengthField := parseLength(headerBuf)
			// Padding entries carry a throwaway payload; skip over them.
			if entryType == EntryTypePadding {
				entrySize := uint64(walEntryHeaderSize) + uint64(lengthField)
				pos += entrySize
				continue
			}
			// Calculate on-disk entry size. WRITE and PADDING carry data payload;
			// TRIM and BARRIER do not (Length is metadata, not data size).
			var payloadLen uint64
			if entryType == EntryTypeWrite {
				payloadLen = uint64(lengthField)
			}
			entrySize := uint64(walEntryHeaderSize) + payloadLen
			if entrySize > remaining {
				// Torn write: entry extends past available data.
				result.TornEntries++
				break
			}
			// Read the full entry (header + payload + footer).
			fullBuf := make([]byte, entrySize)
			if _, err := fd.ReadAt(fullBuf, absOff); err != nil {
				return result, fmt.Errorf("recovery: read entry at WAL+%d: %w", pos, err)
			}
			// Decode and validate CRC.
			entry, err := DecodeWALEntry(fullBuf)
			if err != nil {
				// CRC failure or corrupt entry — stop here (torn write).
				result.TornEntries++
				break
			}
			// Skip entries already flushed to extent.
			if entry.LSN <= checkpointLSN {
				pos += entrySize
				continue
			}
			// Replay entry into the dirty map.
			switch entry.Type {
			case EntryTypeWrite:
				// NOTE(review): every block of a multi-block write is recorded
				// at the entry's start position `pos`; assumes DirtyMap.Put
				// derives the per-block data offset itself — confirm against
				// the Put implementation.
				blocks := entry.Length / sb.BlockSize
				for i := uint32(0); i < blocks; i++ {
					dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize)
				}
				result.EntriesReplayed++
			case EntryTypeTrim:
				// TRIM carries Length (bytes) covering multiple blocks.
				blocks := entry.Length / sb.BlockSize
				if blocks == 0 {
					blocks = 1 // legacy single-block trim
				}
				for i := uint32(0); i < blocks; i++ {
					dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize)
				}
				result.EntriesReplayed++
			case EntryTypeBarrier:
				// Barriers don't modify data, just skip.
			}
			if entry.LSN > result.HighestLSN {
				result.HighestLSN = entry.LSN
			}
			pos += entrySize
		}
	}
	return result, nil
}

416
weed/storage/blockvol/recovery_test.go

@ -0,0 +1,416 @@
package blockvol
import (
"bytes"
"path/filepath"
"testing"
"time"
)
// TestRecovery drives the WAL crash-recovery subtests.
func TestRecovery(t *testing.T) {
	type subtest struct {
		name string
		run  func(t *testing.T)
	}
	for _, st := range []subtest{
		{"recover_empty_wal", testRecoverEmptyWAL},
		{"recover_one_entry", testRecoverOneEntry},
		{"recover_many_entries", testRecoverManyEntries},
		{"recover_torn_write", testRecoverTornWrite},
		{"recover_after_checkpoint", testRecoverAfterCheckpoint},
		{"recover_idempotent", testRecoverIdempotent},
		{"recover_wal_full", testRecoverWALFull},
		{"recover_barrier_only", testRecoverBarrierOnly},
	} {
		t.Run(st.name, st.run)
	}
}
// simulateCrash closes the volume without syncing and returns the path.
// The group-commit worker is stopped first so no goroutine keeps appending,
// then the descriptor is closed directly — deliberately skipping any flush —
// so the on-disk state is exactly what the test already made durable.
func simulateCrash(v *BlockVol) string {
	path := v.Path()
	v.groupCommit.Stop()
	v.fd.Close() // error deliberately ignored: we are simulating an abrupt crash
	return path
}
// testRecoverEmptyWAL verifies that reopening a crashed volume whose WAL
// holds no entries succeeds and that reads return zeros.
func testRecoverEmptyWAL(t *testing.T) {
	v := createTestVol(t)
	// Sync to make superblock durable.
	v.fd.Sync()
	path := simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	// Empty volume: read should return zeros.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("expected zeros from empty volume")
	}
}
// testRecoverOneEntry verifies that a single durable WAL write survives a
// crash and is replayed on reopen.
func testRecoverOneEntry(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	// Sync WAL to make the entry durable.
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Update superblock with WAL state. Order matters: head/tail must be
	// written and synced before the simulated crash, otherwise recovery
	// sees head == tail and treats the WAL as empty.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path := simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not recovered")
	}
}
// testRecoverManyEntries verifies that hundreds of durable WAL writes are
// all replayed after a crash, each block keeping its distinct pattern.
func testRecoverManyEntries(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "many.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 4 * 1024 * 1024, // 4MB
		BlockSize: 4096,
		WALSize: 2 * 1024 * 1024, // 2MB WAL -- enough for ~500 entries
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	// Each block i is filled with a letter derived from i, so misplaced
	// replay shows up as a pattern mismatch below.
	const numWrites = 400
	for i := uint64(0); i < numWrites; i++ {
		if err := v.WriteLBA(i, makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Persist WAL head/tail in the superblock before crashing.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path = simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	for i := uint64(0); i < numWrites; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after recovery: %v", i, err)
		}
		expected := makeBlock(byte(i%26 + 'A'))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after recovery", i)
		}
	}
}
// testRecoverTornWrite verifies that a corrupted trailing entry is discarded
// during recovery while the intact entry before it is still replayed.
func testRecoverTornWrite(t *testing.T) {
	v := createTestVol(t)
	// Write 2 entries.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(0): %v", err)
	}
	if err := v.WriteLBA(1, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(1): %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Save superblock with correct WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	// Corrupt the last 2 bytes of the second entry to simulate torn write.
	// Those bytes land in the entry's EntrySize footer, which is not covered
	// by the CRC but is checked for consistency by DecodeWALEntry.
	entrySize := uint64(walEntryHeaderSize + 4096)
	secondEntryEnd := v.super.WALOffset + entrySize*2
	corruptOff := int64(secondEntryEnd - 2)
	v.fd.WriteAt([]byte{0xFF, 0xFF}, corruptOff)
	v.fd.Sync()
	path := simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	// First entry should be recovered.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, makeBlock('A')) {
		t.Error("block 0 should be recovered")
	}
	// Second entry was torn -- should NOT be in dirty map.
	// Read returns zeros (from extent).
	got, err = v2.ReadLBA(1, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(1) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("block 1 (torn) should return zeros")
	}
}
// testRecoverAfterCheckpoint verifies recovery after a partial flush: blocks
// flushed to the extent region before the crash must still read correctly,
// and blocks written after the checkpoint must be replayed from the WAL.
func testRecoverAfterCheckpoint(t *testing.T) {
	v := createTestVol(t)
	// Write blocks 0-9.
	for i := uint64(0); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Flush first 5 blocks (flusher moves them to extent, advances checkpoint).
	f := NewFlusher(FlusherConfig{
		FD: v.fd,
		Super: &v.super,
		WAL: v.wal,
		DirtyMap: v.dirtyMap,
		Interval: 1 * time.Hour, // effectively disables the background timer
	})
	// Flush all (flusher takes all dirty entries).
	// To simulate partial flush, we'll manually set checkpoint to midpoint.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}
	// Now write 5 more blocks (these will need replay after crash).
	for i := uint64(10); i < 15; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Update superblock so recovery sees the post-flush WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path := simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	// Blocks 0-9 should be in extent (flushed). Blocks 10-14 replayed from WAL.
	for i := uint64(0); i < 15; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after checkpoint recovery", i)
		}
	}
}
// testRecoverIdempotent verifies that recovering the same volume twice in a
// row (crash, open, crash again, open again) yields the same data both times.
func testRecoverIdempotent(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('X')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Persist WAL head/tail in the superblock before crashing.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path := simulateCrash(v)
	// First recovery.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol 1: %v", err)
	}
	// Verify data.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 1: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("first recovery: data mismatch")
	}
	// Close and recover again (should be idempotent).
	path2 := simulateCrash(v2)
	v3, err := OpenBlockVol(path2)
	if err != nil {
		t.Fatalf("OpenBlockVol 2: %v", err)
	}
	defer v3.Close()
	got, err = v3.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 2: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("second recovery: data mismatch")
	}
}
// testRecoverWALFull verifies recovery when the WAL ring is exactly full
// (physical head == physical tail with different logical positions), which
// exercises the full-ring scan path in RecoverWAL.
func testRecoverWALFull(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "full.blockvol")
	// Size the WAL to hold exactly 5 single-block entries.
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 5 // room for exactly 5 entries
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize: 4096,
		WALSize: walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	// Write exactly 5 entries to fill WAL.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Persist WAL head/tail in the superblock before crashing.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path = simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	for i := uint64(0); i < 5; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after full WAL recovery", i)
		}
	}
}
// testRecoverBarrierOnly verifies that a WAL containing only a barrier entry
// recovers cleanly and leaves the data region untouched (all zeros).
func testRecoverBarrierOnly(t *testing.T) {
	v := createTestVol(t)
	// Write a barrier entry directly into the WAL.
	lsn := v.nextLSN.Add(1) - 1
	entry := &WALEntry{
		LSN: lsn,
		Type: EntryTypeBarrier,
		LBA: 0,
	}
	if _, err := v.wal.Append(entry); err != nil {
		t.Fatalf("Append barrier: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	// Persist WAL head/tail in the superblock before crashing.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()
	path := simulateCrash(v)
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()
	// No data changes from barrier.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("barrier-only WAL should leave data as zeros")
	}
}

249
weed/storage/blockvol/superblock.go

@ -0,0 +1,249 @@
package blockvol
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/google/uuid"
)
// On-disk format identification for blockvol files.
const (
	SuperblockSize = 4096 // fixed size of the header block at offset 0
	MagicSWBK = "SWBK" // 4-byte magic at the start of every blockvol file
	CurrentVersion = 1 // only on-disk version this code reads or writes
)
// Sentinel errors returned when reading or validating a superblock.
var (
	ErrNotBlockVol = errors.New("blockvol: not a blockvol file (bad magic)")
	ErrUnsupportedVersion = errors.New("blockvol: unsupported version")
	ErrInvalidVolumeSize = errors.New("blockvol: volume size must be > 0")
)
// Superblock is the 4KB header at offset 0 of a blockvol file.
// It identifies the file format, stores volume geometry, and tracks WAL state.
//
// NOTE(review): the field order is mirrored by superblockOnDisk and by the
// manual little-endian encoding in WriteTo/ReadSuperblock — keep all three
// in sync when adding fields.
type Superblock struct {
	Magic [4]byte // "SWBK"
	Version uint16 // on-disk format version (CurrentVersion)
	Flags uint16
	UUID [16]byte // volume identity, assigned at creation
	VolumeSize uint64 // logical size in bytes
	ExtentSize uint32 // default 64KB
	BlockSize uint32 // default 4KB (min I/O unit)
	WALOffset uint64 // byte offset where WAL region starts
	WALSize uint64 // WAL region size in bytes
	WALHead uint64 // logical WAL write position (monotonically increasing)
	WALTail uint64 // logical WAL flush position (monotonically increasing)
	WALCheckpointLSN uint64 // last LSN flushed to extent region
	Replication [4]byte // replication setting, e.g. "000"
	CreatedAt uint64 // unix timestamp
	SnapshotCount uint32
}
// superblockOnDisk documents the fixed-size on-disk field layout.
// Must be <= SuperblockSize bytes; the remainder of the 4KB block is
// zero-padded.
//
// NOTE(review): serialization is actually done field-by-field in
// WriteTo/ReadSuperblock rather than via encoding/binary on this struct,
// so this layout, Superblock, and those two functions must be kept in sync.
type superblockOnDisk struct {
	Magic [4]byte
	Version uint16
	Flags uint16
	UUID [16]byte
	VolumeSize uint64
	ExtentSize uint32
	BlockSize uint32
	WALOffset uint64
	WALSize uint64
	WALHead uint64
	WALTail uint64
	WALCheckpointLSN uint64
	Replication [4]byte
	CreatedAt uint64
	SnapshotCount uint32
}
// NewSuperblock creates a superblock with defaults and a fresh UUID.
// Zero-valued options fall back to 64KB extents, 4KB blocks, a 64MB WAL,
// and replication "000". Returns ErrInvalidVolumeSize when volumeSize is 0.
func NewSuperblock(volumeSize uint64, opts CreateOptions) (Superblock, error) {
	if volumeSize == 0 {
		return Superblock{}, ErrInvalidVolumeSize
	}
	sb := Superblock{
		Version:    CurrentVersion,
		VolumeSize: volumeSize,
		ExtentSize: opts.ExtentSize,
		BlockSize:  opts.BlockSize,
		WALOffset:  SuperblockSize,
		WALSize:    opts.WALSize,
		UUID:       uuid.New(),
	}
	// Fill in defaults for any geometry the caller left unset.
	if sb.ExtentSize == 0 {
		sb.ExtentSize = 64 * 1024 // 64KB
	}
	if sb.BlockSize == 0 {
		sb.BlockSize = 4096 // 4KB
	}
	if sb.WALSize == 0 {
		sb.WALSize = 64 * 1024 * 1024 // 64MB
	}
	copy(sb.Magic[:], MagicSWBK)
	repl := opts.Replication
	if repl == "" {
		repl = "000"
	}
	copy(sb.Replication[:], repl)
	return sb, nil
}
// WriteTo serializes the superblock to w as a single 4096-byte block.
// Fields are packed little-endian at the front of the block; the rest of
// the block is zero padding. Returns the number of bytes written.
func (sb *Superblock) WriteTo(w io.Writer) (int64, error) {
	blk := make([]byte, SuperblockSize)
	le := binary.LittleEndian
	p := 0
	p += copy(blk[p:], sb.Magic[:])
	le.PutUint16(blk[p:], sb.Version)
	p += 2
	le.PutUint16(blk[p:], sb.Flags)
	p += 2
	p += copy(blk[p:], sb.UUID[:])
	le.PutUint64(blk[p:], sb.VolumeSize)
	p += 8
	le.PutUint32(blk[p:], sb.ExtentSize)
	p += 4
	le.PutUint32(blk[p:], sb.BlockSize)
	p += 4
	le.PutUint64(blk[p:], sb.WALOffset)
	p += 8
	le.PutUint64(blk[p:], sb.WALSize)
	p += 8
	le.PutUint64(blk[p:], sb.WALHead)
	p += 8
	le.PutUint64(blk[p:], sb.WALTail)
	p += 8
	le.PutUint64(blk[p:], sb.WALCheckpointLSN)
	p += 8
	p += copy(blk[p:], sb.Replication[:])
	le.PutUint64(blk[p:], sb.CreatedAt)
	p += 8
	le.PutUint32(blk[p:], sb.SnapshotCount)
	n, err := w.Write(blk)
	return int64(n), err
}
// ReadSuperblock reads and validates a superblock from r.
// It consumes exactly SuperblockSize bytes and checks the magic, version,
// and volume size as they are decoded, in that order.
func ReadSuperblock(r io.Reader) (Superblock, error) {
	blk := make([]byte, SuperblockSize)
	if _, err := io.ReadFull(r, blk); err != nil {
		return Superblock{}, fmt.Errorf("blockvol: read superblock: %w", err)
	}
	le := binary.LittleEndian
	pos := 0
	// take returns the next n bytes of the block and advances the cursor.
	take := func(n int) []byte {
		b := blk[pos : pos+n]
		pos += n
		return b
	}
	var sb Superblock
	copy(sb.Magic[:], take(4))
	if string(sb.Magic[:]) != MagicSWBK {
		return Superblock{}, ErrNotBlockVol
	}
	sb.Version = le.Uint16(take(2))
	if sb.Version != CurrentVersion {
		return Superblock{}, fmt.Errorf("%w: got %d, want %d", ErrUnsupportedVersion, sb.Version, CurrentVersion)
	}
	sb.Flags = le.Uint16(take(2))
	copy(sb.UUID[:], take(16))
	sb.VolumeSize = le.Uint64(take(8))
	if sb.VolumeSize == 0 {
		return Superblock{}, ErrInvalidVolumeSize
	}
	sb.ExtentSize = le.Uint32(take(4))
	sb.BlockSize = le.Uint32(take(4))
	sb.WALOffset = le.Uint64(take(8))
	sb.WALSize = le.Uint64(take(8))
	sb.WALHead = le.Uint64(take(8))
	sb.WALTail = le.Uint64(take(8))
	sb.WALCheckpointLSN = le.Uint64(take(8))
	copy(sb.Replication[:], take(4))
	sb.CreatedAt = le.Uint64(take(8))
	sb.SnapshotCount = le.Uint32(take(4))
	return sb, nil
}
// ErrInvalidSuperblock reports an internally inconsistent superblock.
var ErrInvalidSuperblock = errors.New("blockvol: invalid superblock")
// Validate checks that the superblock fields are internally consistent.
// Checks run in a fixed order and the first failure is returned.
func (sb *Superblock) Validate() error {
	switch {
	case string(sb.Magic[:]) != MagicSWBK:
		return ErrNotBlockVol
	case sb.Version != CurrentVersion:
		return fmt.Errorf("%w: got %d", ErrUnsupportedVersion, sb.Version)
	case sb.VolumeSize == 0:
		return ErrInvalidVolumeSize
	case sb.BlockSize == 0:
		return fmt.Errorf("%w: BlockSize is 0", ErrInvalidSuperblock)
	case sb.ExtentSize == 0:
		return fmt.Errorf("%w: ExtentSize is 0", ErrInvalidSuperblock)
	case sb.WALSize == 0:
		return fmt.Errorf("%w: WALSize is 0", ErrInvalidSuperblock)
	case sb.WALOffset != SuperblockSize:
		return fmt.Errorf("%w: WALOffset=%d, expected %d", ErrInvalidSuperblock, sb.WALOffset, SuperblockSize)
	case sb.VolumeSize%uint64(sb.BlockSize) != 0:
		// BlockSize was already confirmed non-zero above, so the modulo is safe.
		return fmt.Errorf("%w: VolumeSize %d not aligned to BlockSize %d", ErrInvalidSuperblock, sb.VolumeSize, sb.BlockSize)
	}
	return nil
}

146
weed/storage/blockvol/superblock_test.go

@ -0,0 +1,146 @@
package blockvol
import (
"bytes"
"encoding/binary"
"errors"
"testing"
)
// TestSuperblock drives the superblock serialization subtests.
func TestSuperblock(t *testing.T) {
	type subtest struct {
		name string
		run  func(t *testing.T)
	}
	for _, st := range []subtest{
		{"superblock_roundtrip", testSuperblockRoundtrip},
		{"superblock_magic_check", testSuperblockMagicCheck},
		{"superblock_version_check", testSuperblockVersionCheck},
		{"superblock_zero_vol_size", testSuperblockZeroVolSize},
	} {
		t.Run(st.name, st.run)
	}
}
// testSuperblockRoundtrip verifies that every superblock field survives a
// WriteTo/ReadSuperblock roundtrip, and that exactly SuperblockSize bytes
// are emitted.
func testSuperblockRoundtrip(t *testing.T) {
	sb, err := NewSuperblock(100*1024*1024*1024, CreateOptions{
		ExtentSize: 64 * 1024,
		BlockSize: 4096,
		WALSize: 64 * 1024 * 1024,
		Replication: "010",
	})
	if err != nil {
		t.Fatalf("NewSuperblock: %v", err)
	}
	// Populate the fields NewSuperblock leaves at zero so the roundtrip
	// exercises them too.
	sb.CreatedAt = 1700000000
	sb.SnapshotCount = 3
	sb.WALHead = 8192
	sb.WALCheckpointLSN = 42
	var buf bytes.Buffer
	n, err := sb.WriteTo(&buf)
	if err != nil {
		t.Fatalf("WriteTo: %v", err)
	}
	if n != SuperblockSize {
		t.Fatalf("WriteTo wrote %d bytes, want %d", n, SuperblockSize)
	}
	got, err := ReadSuperblock(&buf)
	if err != nil {
		t.Fatalf("ReadSuperblock: %v", err)
	}
	if string(got.Magic[:]) != MagicSWBK {
		t.Errorf("Magic = %q, want %q", got.Magic, MagicSWBK)
	}
	if got.Version != CurrentVersion {
		t.Errorf("Version = %d, want %d", got.Version, CurrentVersion)
	}
	if got.UUID != sb.UUID {
		t.Errorf("UUID mismatch")
	}
	if got.VolumeSize != sb.VolumeSize {
		t.Errorf("VolumeSize = %d, want %d", got.VolumeSize, sb.VolumeSize)
	}
	if got.ExtentSize != sb.ExtentSize {
		t.Errorf("ExtentSize = %d, want %d", got.ExtentSize, sb.ExtentSize)
	}
	if got.BlockSize != sb.BlockSize {
		t.Errorf("BlockSize = %d, want %d", got.BlockSize, sb.BlockSize)
	}
	if got.WALOffset != sb.WALOffset {
		t.Errorf("WALOffset = %d, want %d", got.WALOffset, sb.WALOffset)
	}
	if got.WALSize != sb.WALSize {
		t.Errorf("WALSize = %d, want %d", got.WALSize, sb.WALSize)
	}
	if got.WALHead != sb.WALHead {
		t.Errorf("WALHead = %d, want %d", got.WALHead, sb.WALHead)
	}
	if got.WALCheckpointLSN != sb.WALCheckpointLSN {
		t.Errorf("WALCheckpointLSN = %d, want %d", got.WALCheckpointLSN, sb.WALCheckpointLSN)
	}
	if got.Replication != sb.Replication {
		t.Errorf("Replication = %q, want %q", got.Replication, sb.Replication)
	}
	if got.CreatedAt != sb.CreatedAt {
		t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, sb.CreatedAt)
	}
	if got.SnapshotCount != sb.SnapshotCount {
		t.Errorf("SnapshotCount = %d, want %d", got.SnapshotCount, sb.SnapshotCount)
	}
}
// testSuperblockMagicCheck verifies a corrupted magic is rejected on read.
func testSuperblockMagicCheck(t *testing.T) {
	// Write a valid superblock, then clobber the 4 magic bytes.
	sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{})
	var buf bytes.Buffer
	sb.WriteTo(&buf)
	raw := buf.Bytes()
	copy(raw[0:4], "XXXX")
	if _, err := ReadSuperblock(bytes.NewReader(raw)); !errors.Is(err, ErrNotBlockVol) {
		t.Errorf("expected ErrNotBlockVol, got %v", err)
	}
}
// testSuperblockVersionCheck verifies an unknown version is rejected on read.
func testSuperblockVersionCheck(t *testing.T) {
	sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{})
	var buf bytes.Buffer
	sb.WriteTo(&buf)
	raw := buf.Bytes()
	// Version lives at byte offset 4 as a little-endian uint16.
	binary.LittleEndian.PutUint16(raw[4:], 99)
	if _, err := ReadSuperblock(bytes.NewReader(raw)); !errors.Is(err, ErrUnsupportedVersion) {
		t.Errorf("expected ErrUnsupportedVersion, got %v", err)
	}
}
// testSuperblockZeroVolSize verifies that a zero volume size is rejected
// both at construction time and when decoding a tampered superblock.
func testSuperblockZeroVolSize(t *testing.T) {
	if _, err := NewSuperblock(0, CreateOptions{}); !errors.Is(err, ErrInvalidVolumeSize) {
		t.Errorf("NewSuperblock(0): expected ErrInvalidVolumeSize, got %v", err)
	}
	sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{})
	var buf bytes.Buffer
	sb.WriteTo(&buf)
	raw := buf.Bytes()
	// VolumeSize lives at byte offset 24 (magic 4 + version 2 + flags 2 + UUID 16).
	binary.LittleEndian.PutUint64(raw[24:], 0)
	if _, err := ReadSuperblock(bytes.NewReader(raw)); !errors.Is(err, ErrInvalidVolumeSize) {
		t.Errorf("ReadSuperblock(vol_size=0): expected ErrInvalidVolumeSize, got %v", err)
	}
}

153
weed/storage/blockvol/wal_entry.go

@ -0,0 +1,153 @@
package blockvol
import (
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
)
// WAL entry types and the fixed header geometry.
const (
	EntryTypeWrite = 0x01 // data write: Length bytes of payload follow the header prefix
	EntryTypeTrim = 0x02 // trim: Length is the trimmed extent in bytes, no payload
	EntryTypeBarrier = 0x03 // ordering barrier: no Length, no payload
	EntryTypePadding = 0xFF // ring filler: Length bytes of throwaway payload
	// walEntryHeaderSize is the fixed portion: LSN(8) + Epoch(8) + Type(1) +
	// Flags(1) + LBA(8) + Length(4) + CRC32(4) + EntrySize(4) = 38 bytes.
	walEntryHeaderSize = 38
)
// Sentinel errors for WAL entry encoding and decoding.
var (
	ErrCRCMismatch = errors.New("blockvol: CRC mismatch")
	ErrInvalidEntry = errors.New("blockvol: invalid WAL entry")
	ErrEntryTruncated = errors.New("blockvol: WAL entry truncated")
)
// WALEntry is a variable-size record in the WAL region.
//
// On-disk layout: LSN(8) Epoch(8) Type(1) Flags(1) LBA(8) Length(4)
// [Data...] CRC32(4) EntrySize(4). The CRC covers LSN through Data; the
// EntrySize footer is outside CRC coverage but is consistency-checked by
// DecodeWALEntry.
type WALEntry struct {
	LSN uint64
	Epoch uint64 // writer's epoch (Phase 1: always 0)
	Type uint8 // EntryTypeWrite, EntryTypeTrim, EntryTypeBarrier, EntryTypePadding
	Flags uint8
	LBA uint64 // in blocks
	Length uint32 // data length in bytes (WRITE/PADDING: payload size, TRIM: trim size, BARRIER: 0)
	Data []byte // present only for WRITE and PADDING
	CRC32 uint32 // covers LSN through Data; filled in by Encode
	EntrySize uint32 // total serialized size; filled in by Encode
}
// Encode serializes the entry into a byte slice, computing CRC over all fields.
// It also stores the computed CRC32 and EntrySize back into e.
//
// Per-type validation before encoding:
//   - WRITE: must carry Data, and len(Data) must equal Length.
//   - TRIM: Length is the trim extent in bytes; Data must be empty.
//   - BARRIER: both Length and Data must be zero/empty.
//   - PADDING: len(Data) must equal Length, because DecodeWALEntry and WAL
//     recovery both derive a padding entry's on-disk size from the Length
//     field; a mismatch would produce an unparseable entry.
func (e *WALEntry) Encode() ([]byte, error) {
	switch e.Type {
	case EntryTypeWrite:
		if len(e.Data) == 0 {
			return nil, fmt.Errorf("%w: WRITE entry with no data", ErrInvalidEntry)
		}
		if uint32(len(e.Data)) != e.Length {
			return nil, fmt.Errorf("%w: data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length)
		}
	case EntryTypeTrim:
		// TRIM carries Length (trim size in bytes) but no Data payload.
		if len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: TRIM entry must have no data payload", ErrInvalidEntry)
		}
	case EntryTypeBarrier:
		if e.Length != 0 || len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: BARRIER entry must have no data", ErrInvalidEntry)
		}
	case EntryTypePadding:
		// The decoder sizes padding entries from Length, so the payload must
		// actually be Length bytes long.
		if uint32(len(e.Data)) != e.Length {
			return nil, fmt.Errorf("%w: PADDING data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length)
		}
	}
	totalSize := uint32(walEntryHeaderSize + len(e.Data))
	buf := make([]byte, totalSize)
	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], e.LSN)
	off += 8
	le.PutUint64(buf[off:], e.Epoch)
	off += 8
	buf[off] = e.Type
	off++
	buf[off] = e.Flags
	off++
	le.PutUint64(buf[off:], e.LBA)
	off += 8
	le.PutUint32(buf[off:], e.Length)
	off += 4
	if len(e.Data) > 0 {
		copy(buf[off:], e.Data)
		off += len(e.Data)
	}
	// CRC covers everything from start through Data (offset 0 to off).
	checksum := crc32.ChecksumIEEE(buf[:off])
	le.PutUint32(buf[off:], checksum)
	off += 4
	le.PutUint32(buf[off:], totalSize)
	e.CRC32 = checksum
	e.EntrySize = totalSize
	return buf, nil
}
// DecodeWALEntry deserializes a WAL entry from buf, validating CRC.
// It returns ErrEntryTruncated when buf is too short for the header, the
// payload, or the footer; ErrCRCMismatch on checksum failure; and
// ErrInvalidEntry when the EntrySize footer disagrees with the actual layout.
func DecodeWALEntry(buf []byte) (WALEntry, error) {
	if len(buf) < walEntryHeaderSize {
		return WALEntry{}, fmt.Errorf("%w: need %d bytes, have %d", ErrEntryTruncated, walEntryHeaderSize, len(buf))
	}
	le := binary.LittleEndian
	var e WALEntry
	off := 0
	e.LSN = le.Uint64(buf[off:])
	off += 8
	e.Epoch = le.Uint64(buf[off:])
	off += 8
	e.Type = buf[off]
	off++
	e.Flags = buf[off]
	off++
	e.LBA = le.Uint64(buf[off:])
	off += 8
	e.Length = le.Uint32(buf[off:])
	off += 4
	// For WRITE entries, Length is the data payload size.
	// For TRIM entries, Length is the trim extent in bytes (no data payload).
	// For BARRIER/PADDING, Length is 0 (or padding size).
	var dataLen int
	if e.Type == EntryTypeWrite || e.Type == EntryTypePadding {
		dataLen = int(e.Length)
	}
	dataEnd := off + dataLen
	if dataEnd+8 > len(buf) { // +8 for CRC32 + EntrySize
		return WALEntry{}, fmt.Errorf("%w: need %d bytes for data+footer, have %d", ErrEntryTruncated, dataEnd+8, len(buf))
	}
	if dataLen > 0 {
		e.Data = make([]byte, dataLen)
		copy(e.Data, buf[off:dataEnd])
	}
	off = dataEnd
	e.CRC32 = le.Uint32(buf[off:])
	off += 4
	e.EntrySize = le.Uint32(buf[off:])
	// Verify CRC: covers LSN through Data.
	expected := crc32.ChecksumIEEE(buf[:dataEnd])
	if e.CRC32 != expected {
		return WALEntry{}, fmt.Errorf("%w: stored=%08x computed=%08x", ErrCRCMismatch, e.CRC32, expected)
	}
	// Verify EntrySize matches actual layout.
	expectedSize := uint32(walEntryHeaderSize) + uint32(dataLen)
	if e.EntrySize != expectedSize {
		return WALEntry{}, fmt.Errorf("%w: EntrySize=%d, expected=%d", ErrInvalidEntry, e.EntrySize, expectedSize)
	}
	return e, nil
}

284
weed/storage/blockvol/wal_entry_test.go

@ -0,0 +1,284 @@
package blockvol
import (
"bytes"
"encoding/binary"
"errors"
"testing"
)
// TestWALEntry drives the WAL entry codec subtests.
func TestWALEntry(t *testing.T) {
	type subtest struct {
		name string
		run  func(t *testing.T)
	}
	for _, st := range []subtest{
		{"wal_entry_roundtrip", testWALEntryRoundtrip},
		{"wal_entry_trim_no_data", testWALEntryTrimNoData},
		{"wal_entry_barrier", testWALEntryBarrier},
		{"wal_entry_crc_valid", testWALEntryCRCValid},
		{"wal_entry_crc_corrupt", testWALEntryCRCCorrupt},
		{"wal_entry_max_size", testWALEntryMaxSize},
		{"wal_entry_zero_length", testWALEntryZeroLength},
		{"wal_entry_trim_rejects_data", testWALEntryTrimRejectsData},
		{"wal_entry_barrier_rejects_data", testWALEntryBarrierRejectsData},
		{"wal_entry_bad_entry_size", run: testWALEntryBadEntrySize},
	} {
		t.Run(st.name, st.run)
	}
}
// testWALEntryRoundtrip verifies that a 4KB WRITE entry survives an
// Encode/DecodeWALEntry roundtrip with every field intact.
func testWALEntryRoundtrip(t *testing.T) {
	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i % 251) // deterministic pattern
	}
	e := WALEntry{
		LSN: 1,
		Epoch: 0,
		Type: EntryTypeWrite,
		Flags: 0,
		LBA: 100,
		Length: uint32(len(data)),
		Data: data,
	}
	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}
	if got.LSN != e.LSN {
		t.Errorf("LSN = %d, want %d", got.LSN, e.LSN)
	}
	if got.Epoch != e.Epoch {
		t.Errorf("Epoch = %d, want %d", got.Epoch, e.Epoch)
	}
	if got.Type != e.Type {
		t.Errorf("Type = %d, want %d", got.Type, e.Type)
	}
	if got.LBA != e.LBA {
		t.Errorf("LBA = %d, want %d", got.LBA, e.LBA)
	}
	if got.Length != e.Length {
		t.Errorf("Length = %d, want %d", got.Length, e.Length)
	}
	if !bytes.Equal(got.Data, e.Data) {
		t.Errorf("Data mismatch")
	}
	if got.EntrySize != uint32(walEntryHeaderSize+len(data)) {
		t.Errorf("EntrySize = %d, want %d", got.EntrySize, walEntryHeaderSize+len(data))
	}
}
// testWALEntryTrimNoData verifies that a TRIM entry roundtrips with its
// Length intact while carrying no payload, so its encoded size is exactly
// the fixed header.
func testWALEntryTrimNoData(t *testing.T) {
	// TRIM carries Length (trim size in bytes) but no Data payload.
	e := WALEntry{
		LSN: 5,
		Epoch: 0,
		Type: EntryTypeTrim,
		LBA: 200,
		Length: 4096,
	}
	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}
	if got.Type != EntryTypeTrim {
		t.Errorf("Type = %d, want %d", got.Type, EntryTypeTrim)
	}
	if got.Length != 4096 {
		t.Errorf("Length = %d, want 4096", got.Length)
	}
	if len(got.Data) != 0 {
		t.Errorf("Data should be empty, got %d bytes", len(got.Data))
	}
	// TRIM with Length but no Data: EntrySize = header only.
	if got.EntrySize != uint32(walEntryHeaderSize) {
		t.Errorf("EntrySize = %d, want %d (header only)", got.EntrySize, walEntryHeaderSize)
	}
}
// testWALEntryBarrier verifies that a BARRIER entry roundtrips and carries
// no data payload.
func testWALEntryBarrier(t *testing.T) {
	e := WALEntry{
		LSN: 10,
		Epoch: 0,
		Type: EntryTypeBarrier,
	}
	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}
	if got.Type != EntryTypeBarrier {
		t.Errorf("Type = %d, want %d", got.Type, EntryTypeBarrier)
	}
	if len(got.Data) != 0 {
		t.Errorf("Barrier should have no data, got %d bytes", len(got.Data))
	}
}
// testWALEntryCRCValid checks that an untouched encoded entry decodes
// cleanly, i.e. the CRC written by Encode verifies on the way back in.
func testWALEntryCRCValid(t *testing.T) {
	payload := []byte("hello blockvol WAL")
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(payload)),
		Data:   payload,
	}
	encoded, err := entry.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	// An uncorrupted buffer must pass CRC verification.
	if _, err := DecodeWALEntry(encoded); err != nil {
		t.Fatalf("valid entry should decode without error, got: %v", err)
	}
}
// testWALEntryCRCCorrupt flips a single bit in the encoded entry's data
// region and expects decode to fail with ErrCRCMismatch.
func testWALEntryCRCCorrupt(t *testing.T) {
	payload := []byte("hello blockvol WAL")
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(payload)),
		Data:   payload,
	}
	encoded, err := entry.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	// Corrupt the first byte of the data region (the fixed header prefix is
	// walEntryHeaderSize-8 bytes; the 8-byte CRC/EntrySize footer is at the end).
	encoded[walEntryHeaderSize-8] ^= 0x01
	if _, err := DecodeWALEntry(encoded); !errors.Is(err, ErrCRCMismatch) {
		t.Errorf("expected ErrCRCMismatch, got %v", err)
	}
}
// testWALEntryMaxSize round-trips a large (64KB) payload and verifies the
// data survives encode/decode byte-for-byte.
func testWALEntryMaxSize(t *testing.T) {
	payload := make([]byte, 64*1024) // 64KB
	for i := range payload {
		payload[i] = byte(i) // repeating 0..255 pattern
	}
	entry := WALEntry{
		LSN:    99,
		Epoch:  0,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(payload)),
		Data:   payload,
	}
	encoded, err := entry.Encode()
	if err != nil {
		t.Fatalf("Encode 64KB entry: %v", err)
	}
	out, err := DecodeWALEntry(encoded)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}
	if !bytes.Equal(out.Data, payload) {
		t.Errorf("64KB data mismatch after roundtrip")
	}
}
// testWALEntryZeroLength verifies that Encode rejects a WRITE entry that
// carries no data at all.
func testWALEntryZeroLength(t *testing.T) {
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: 0,
		Data:   nil,
	}
	if _, err := entry.Encode(); !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("WRITE with zero data: expected ErrInvalidEntry, got %v", err)
	}
}
// testWALEntryTrimRejectsData verifies that a TRIM entry may carry a Length
// (the trim extent) but Encode rejects any Data payload.
func testWALEntryTrimRejectsData(t *testing.T) {
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeTrim,
		LBA:    0,
		Length: 4096,
		Data:   []byte("bad!"),
	}
	if _, err := entry.Encode(); !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("TRIM with data payload: expected ErrInvalidEntry, got %v", err)
	}
}
// testWALEntryBarrierRejectsData verifies that Encode rejects a BARRIER
// entry that carries any payload.
func testWALEntryBarrierRejectsData(t *testing.T) {
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeBarrier,
		Length: 1,
		Data:   []byte("x"),
	}
	if _, err := entry.Encode(); !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("BARRIER with data: expected ErrInvalidEntry, got %v", err)
	}
}
// testWALEntryBadEntrySize overwrites the trailing EntrySize field of an
// encoded entry and expects decode to reject it with ErrInvalidEntry.
func testWALEntryBadEntrySize(t *testing.T) {
	payload := []byte("test data for entry size validation")
	entry := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(payload)),
		Data:   payload,
	}
	encoded, err := entry.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	// EntrySize lives in the final 4 bytes of the encoding; stomp it.
	binary.LittleEndian.PutUint32(encoded[len(encoded)-4:], 9999)
	if _, err := DecodeWALEntry(encoded); !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("bad EntrySize: expected ErrInvalidEntry, got %v", err)
	}
}

192
weed/storage/blockvol/wal_writer.go

@ -0,0 +1,192 @@
package blockvol
import (
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"os"
"sync"
)
var (
	// ErrWALFull is returned by Append when the circular WAL region has no
	// free space for the entry (the flusher has not advanced the tail far
	// enough, or the entry is larger than the whole region).
	ErrWALFull = errors.New("blockvol: WAL region full")
)
// WALWriter appends entries to the circular WAL region of a blockvol file.
//
// It uses logical (monotonically increasing) head and tail counters to track
// used space. Physical position = logical % walSize. This eliminates the
// classic circular buffer ambiguity where head==tail could mean empty or full.
// Used space = logicalHead - logicalTail. Free space = walSize - used.
//
// Append, AdvanceTail, and the position accessors all take mu, so one
// appender and one flusher goroutine can safely share a WALWriter. The
// struct embeds a Mutex and must not be copied.
type WALWriter struct {
	mu          sync.Mutex
	fd          *os.File
	walOffset   uint64 // absolute file offset where WAL region starts
	walSize     uint64 // size of the WAL region in bytes
	logicalHead uint64 // monotonically increasing write position (bytes ever appended)
	logicalTail uint64 // monotonically increasing flush position (bytes ever reclaimed)
}
// NewWALWriter creates a WAL writer for the given file.
//
// head and tail are physical positions relative to the WAL region start
// (e.g. recovered from the superblock). For a fresh WAL, both are 0.
//
// Internally the writer tracks logical, monotonically increasing counters
// with the invariant head >= tail. A recovered WAL whose head has wrapped
// past the end of the region arrives with physical head < physical tail;
// without normalization, used() = head - tail would underflow uint64 and
// the WAL would appear permanently full. We restore the invariant by
// unwrapping the head one region length forward. A recovered head == tail
// is treated as empty, matching the original behavior.
func NewWALWriter(fd *os.File, walOffset, walSize, head, tail uint64) *WALWriter {
	if head < tail {
		// Physical head wrapped past region end: unwrap to keep head >= tail.
		head += walSize
	}
	return &WALWriter{
		fd:          fd,
		walOffset:   walOffset,
		walSize:     walSize,
		logicalHead: head,
		logicalTail: tail,
	}
}
// physicalPos maps a monotonically increasing logical counter onto its byte
// offset within the circular WAL region.
func (w *WALWriter) physicalPos(logical uint64) uint64 {
	return logical % w.walSize
}
// used reports how many bytes of the WAL region currently hold entries that
// the flusher has not yet reclaimed (logical head minus logical tail).
func (w *WALWriter) used() uint64 {
	return w.logicalHead - w.logicalTail
}
// Append writes a serialized WAL entry to the circular WAL region.
// Returns the physical WAL-relative offset where the entry was written.
// If the entry doesn't fit in the remaining space before the region end,
// a padding entry is written and the real entry starts at physical offset 0.
//
// The entry is written with a positional write (WriteAt) but is NOT fsynced
// here; callers that need durability must call Sync afterwards.
func (w *WALWriter) Append(entry *WALEntry) (walRelOffset uint64, err error) {
	// Serialize outside the lock; Encode also validates the entry.
	buf, err := entry.Encode()
	if err != nil {
		return 0, fmt.Errorf("WALWriter.Append: encode: %w", err)
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	entryLen := uint64(len(buf))
	// An entry larger than the entire region can never fit, regardless of
	// how far the tail advances.
	if entryLen > w.walSize {
		return 0, fmt.Errorf("%w: entry size %d exceeds WAL size %d", ErrWALFull, entryLen, w.walSize)
	}
	physHead := w.physicalPos(w.logicalHead)
	remaining := w.walSize - physHead
	if remaining < entryLen {
		// Not enough room at end of region -- write padding and wrap.
		// Padding consumes 'remaining' bytes logically.
		// Check space for padding + entry BEFORE writing the padding, so a
		// full WAL fails cleanly instead of overwriting live entries.
		if w.used()+remaining+entryLen > w.walSize {
			return 0, ErrWALFull
		}
		if err := w.writePadding(remaining, physHead); err != nil {
			return 0, fmt.Errorf("WALWriter.Append: padding: %w", err)
		}
		// Account for the padding logically, then continue at offset 0.
		w.logicalHead += remaining
		physHead = 0
	}
	// Check if there's enough free space for the entry.
	// (Redundant on the wrap path above, but it is the only check on the
	// common non-wrap path.)
	if w.used()+entryLen > w.walSize {
		return 0, ErrWALFull
	}
	absOffset := int64(w.walOffset + physHead)
	if _, err := w.fd.WriteAt(buf, absOffset); err != nil {
		return 0, fmt.Errorf("WALWriter.Append: pwrite at offset %d: %w", absOffset, err)
	}
	// Advance the head only after the write succeeded.
	writeOffset := physHead
	w.logicalHead += entryLen
	return writeOffset, nil
}
// writePadding writes a padding entry at the given physical position.
//
// The padding fills the 'size' bytes between the current head and the end of
// the region so that the next real entry can start at physical offset 0.
// If there is room, the padding is a well-formed EntryTypePadding entry
// (header + zero data + CRC + EntrySize footer) so recovery can parse over
// it; if the gap is smaller than one header, the bytes are simply zeroed --
// NOTE(review): recovery is presumed to treat an all-zero region tail as
// end-of-log; confirm against the recovery scanner.
func (w *WALWriter) writePadding(size uint64, physPos uint64) error {
	if size < walEntryHeaderSize {
		// Too small for a proper entry header -- zero it out.
		buf := make([]byte, size)
		absOffset := int64(w.walOffset + physPos)
		_, err := w.fd.WriteAt(buf, absOffset)
		return err
	}
	// Hand-roll the entry layout (must match WALEntry.Encode):
	//   LSN(8) Epoch(8) Type(1) Flags(1) LBA(8) Length(4) | data | CRC(4) EntrySize(4)
	buf := make([]byte, size)
	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], 0) // LSN=0
	off += 8
	le.PutUint64(buf[off:], 0) // Epoch=0
	off += 8
	buf[off] = EntryTypePadding
	off++
	buf[off] = 0 // Flags
	off++
	le.PutUint64(buf[off:], 0) // LBA=0
	off += 8
	// The padding "data" is whatever is left after the header bytes.
	paddingDataLen := uint32(size) - uint32(walEntryHeaderSize)
	le.PutUint32(buf[off:], paddingDataLen)
	off += 4
	// CRC covers the header prefix plus the (zeroed) data region.
	dataEnd := off + int(paddingDataLen)
	crc := crc32.ChecksumIEEE(buf[:dataEnd])
	le.PutUint32(buf[dataEnd:], crc)
	le.PutUint32(buf[dataEnd+4:], uint32(size))
	absOffset := int64(w.walOffset + physPos)
	_, err := w.fd.WriteAt(buf, absOffset)
	return err
}
// AdvanceTail moves the tail forward, freeing WAL space.
// Called by the flusher after entries have been written to the extent region.
// newTail is a physical position; it is converted to a logical advance.
//
// NOTE(review): a newTail equal to the current physical tail is treated as
// an advance of 0 (no-op), never as "one full region reclaimed" -- a caller
// that frees exactly walSize bytes in a single call would be misinterpreted.
// Confirm the flusher never does that in one step.
func (w *WALWriter) AdvanceTail(newTail uint64) {
	w.mu.Lock()
	physTail := w.physicalPos(w.logicalTail)
	var advance uint64
	if newTail >= physTail {
		advance = newTail - physTail
	} else {
		// Tail wrapped around.
		advance = w.walSize - physTail + newTail
	}
	w.logicalTail += advance
	w.mu.Unlock()
}
// Head returns the current physical head position (relative to WAL start).
func (w *WALWriter) Head() uint64 {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.physicalPos(w.logicalHead)
}
// Tail returns the current physical tail position (relative to WAL start).
func (w *WALWriter) Tail() uint64 {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.physicalPos(w.logicalTail)
}
// LogicalHead returns the logical (monotonically increasing) head position.
func (w *WALWriter) LogicalHead() uint64 {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.logicalHead
}
// LogicalTail returns the logical (monotonically increasing) tail position.
func (w *WALWriter) LogicalTail() uint64 {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.logicalTail
}
// Sync flushes the underlying file descriptor to stable storage (fsync).
// It takes no lock: the fsync applies to whatever has been written so far.
func (w *WALWriter) Sync() error {
	return w.fd.Sync()
}

244
weed/storage/blockvol/wal_writer_test.go

@ -0,0 +1,244 @@
package blockvol
import (
"os"
"path/filepath"
"testing"
)
// TestWALWriter drives all WALWriter sub-scenarios as named subtests.
func TestWALWriter(t *testing.T) {
	cases := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "wal_writer_append_read_back", run: testWALWriterAppendReadBack},
		{name: "wal_writer_multiple_entries", run: testWALWriterMultipleEntries},
		{name: "wal_writer_wrap_around", run: testWALWriterWrapAround},
		{name: "wal_writer_full", run: testWALWriterFull},
		{name: "wal_writer_advance_tail_frees_space", run: testWALWriterAdvanceTailFreesSpace},
		{name: "wal_writer_fill_no_flusher", run: testWALWriterFillNoFlusher},
	}
	for _, tc := range cases {
		t.Run(tc.name, tc.run)
	}
}
// createTestWAL creates a temp file sized to hold a WAL region at the given
// offset and size. It returns the open file and a cleanup func that closes it.
func createTestWAL(t *testing.T, walOffset, walSize uint64) (*os.File, func()) {
	t.Helper()
	path := filepath.Join(t.TempDir(), "test.blockvol")
	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("create temp file: %v", err)
	}
	// Grow the file so both the superblock area and the WAL region exist.
	if err := fd.Truncate(int64(walOffset + walSize)); err != nil {
		fd.Close()
		t.Fatalf("truncate: %v", err)
	}
	return fd, func() { fd.Close() }
}
// testWALWriterAppendReadBack appends one entry, reads the raw bytes back
// from the file, and verifies they decode to the original entry.
func testWALWriterAppendReadBack(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	walSize := uint64(64 * 1024) // 64KB WAL
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	// Pad the payload to a full block for a realistic entry.
	payload := make([]byte, 4096)
	copy(payload, []byte("hello WAL writer test!"))
	entry := &WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(payload)),
		Data:   payload,
	}
	gotOff, err := writer.Append(entry)
	if err != nil {
		t.Fatalf("Append: %v", err)
	}
	if gotOff != 0 {
		t.Errorf("first entry offset = %d, want 0", gotOff)
	}
	// Re-read the raw bytes straight from the file and decode them.
	raw := make([]byte, entry.EntrySize)
	if _, err := fd.ReadAt(raw, int64(walOffset+gotOff)); err != nil {
		t.Fatalf("ReadAt: %v", err)
	}
	back, err := DecodeWALEntry(raw)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}
	if back.LSN != 1 {
		t.Errorf("decoded LSN = %d, want 1", back.LSN)
	}
}
// testWALWriterMultipleEntries appends five entries and checks the physical
// head lands exactly where the summed entry sizes say it should.
func testWALWriterMultipleEntries(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	walSize := uint64(64 * 1024)
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	const count = 5
	for i := uint64(1); i <= count; i++ {
		e := &WALEntry{
			LSN:    i,
			Type:   EntryTypeWrite,
			LBA:    i * 10,
			Length: 4096,
			Data:   make([]byte, 4096),
		}
		if _, err := writer.Append(e); err != nil {
			t.Fatalf("Append entry %d: %v", i, err)
		}
	}
	wantHead := uint64(count * (walEntryHeaderSize + 4096))
	if writer.Head() != wantHead {
		t.Errorf("head = %d, want %d", writer.Head(), wantHead)
	}
}
// testWALWriterWrapAround exercises the padding+wrap path: a WAL sized at
// 2.5 entries holds two entries, the leftover half-entry is too small for a
// third, so after the tail is advanced the third entry wraps to offset 0.
func testWALWriterWrapAround(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize*2 + entrySize/2 // 2.5 entries
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	// Two entries occupy most of the region.
	for i := uint64(1); i <= 2; i++ {
		e := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)}
		if _, err := writer.Append(e); err != nil {
			t.Fatalf("Append entry %d: %v", i, err)
		}
	}
	// Pretend the flusher reclaimed both entries so the wrap has free space.
	writer.AdvanceTail(entrySize * 2)
	// The third append must pad the tail-end gap and wrap to offset 0.
	third := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 3, Length: 4096, Data: make([]byte, 4096)}
	gotOff, err := writer.Append(third)
	if err != nil {
		t.Fatalf("Append entry 3 (wrap): %v", err)
	}
	if gotOff != 0 {
		t.Errorf("wrapped entry offset = %d, want 0", gotOff)
	}
}
// testWALWriterFull fills a WAL that holds exactly two entries and checks
// that a third append fails while the tail has not moved.
func testWALWriterFull(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 2 // fits exactly 2 entries
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	// Exact fit: two entries consume the whole region.
	for i := uint64(1); i <= 2; i++ {
		e := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)}
		if _, err := writer.Append(e); err != nil {
			t.Fatalf("Append entry %d: %v", i, err)
		}
	}
	// With the tail pinned at 0, there is no free space for a third entry.
	// NOTE(review): this only asserts a non-nil error; asserting
	// errors.Is(err, ErrWALFull) would be stricter.
	third := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)}
	if _, err := writer.Append(third); err == nil {
		t.Fatal("expected ErrWALFull when WAL is full")
	}
}
// testWALWriterAdvanceTailFreesSpace checks that a full WAL rejects appends
// until AdvanceTail reclaims space, after which an append succeeds.
func testWALWriterAdvanceTailFreesSpace(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 2
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	// Two entries fill the region exactly.
	for i := uint64(1); i <= 2; i++ {
		e := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)}
		if _, err := writer.Append(e); err != nil {
			t.Fatalf("Append entry %d: %v", i, err)
		}
	}
	third := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)}
	// Full WAL: the append must be refused...
	if _, err := writer.Append(third); err == nil {
		t.Fatal("expected ErrWALFull before AdvanceTail")
	}
	// ...until the flusher (simulated here) frees one entry's worth of space.
	writer.AdvanceTail(entrySize)
	if _, err := writer.Append(third); err != nil {
		t.Fatalf("Append after AdvanceTail: %v", err)
	}
}
// testWALWriterFillNoFlusher is the QA-001 regression: with no flusher the
// tail never moves, so after wrapping, head==tail (physically) must mean
// "full", not "empty". Appending must stop with an error instead of silently
// overwriting live entries.
func testWALWriterFillNoFlusher(t *testing.T) {
	walOffset := uint64(SuperblockSize)
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 10 // room for ~10 entries
	fd, cleanup := createTestWAL(t, walOffset, walSize)
	defer cleanup()
	writer := NewWALWriter(fd, walOffset, walSize, 0, 0)
	// Keep appending until the writer refuses; the tail is never advanced.
	appended := 0
	for lsn := uint64(1); ; lsn++ {
		e := &WALEntry{LSN: lsn, Type: EntryTypeWrite, LBA: lsn, Length: 4096, Data: make([]byte, 4096)}
		if _, err := writer.Append(e); err != nil {
			// Expected terminal condition: ErrWALFull, not a silent wrap.
			break
		}
		appended++
		if appended > 20 {
			t.Fatalf("wrote %d entries to a 10-entry WAL without ErrWALFull -- wrap overwrote live entries (QA-001)", appended)
		}
	}
	if appended == 0 {
		t.Fatal("should have written at least 1 entry")
	}
	t.Logf("correctly wrote %d entries then got ErrWALFull", appended)
}
Loading…
Cancel
Save