diff --git a/weed/storage/blockvol/blockvol.go b/weed/storage/blockvol/blockvol.go new file mode 100644 index 000000000..8062582b9 --- /dev/null +++ b/weed/storage/blockvol/blockvol.go @@ -0,0 +1,371 @@ +// Package blockvol implements a block volume storage engine with WAL, +// dirty map, group commit, and crash recovery. It operates on a single +// file and has no dependency on SeaweedFS internals. +package blockvol + +import ( + "encoding/binary" + "fmt" + "os" + "sync" + "sync/atomic" + "time" +) + +// CreateOptions configures a new block volume. +type CreateOptions struct { + VolumeSize uint64 // required, logical size in bytes + ExtentSize uint32 // default 64KB + BlockSize uint32 // default 4KB + WALSize uint64 // default 64MB + Replication string // default "000" +} + +// BlockVol is the core block volume engine. +type BlockVol struct { + mu sync.RWMutex + fd *os.File + path string + super Superblock + wal *WALWriter + dirtyMap *DirtyMap + groupCommit *GroupCommitter + flusher *Flusher + nextLSN atomic.Uint64 + healthy atomic.Bool +} + +// CreateBlockVol creates a new block volume file at path. +func CreateBlockVol(path string, opts CreateOptions) (*BlockVol, error) { + if opts.VolumeSize == 0 { + return nil, ErrInvalidVolumeSize + } + + sb, err := NewSuperblock(opts.VolumeSize, opts) + if err != nil { + return nil, fmt.Errorf("blockvol: create superblock: %w", err) + } + sb.CreatedAt = uint64(time.Now().Unix()) + + // Extent region starts after superblock + WAL. 
+ extentStart := sb.WALOffset + sb.WALSize + totalFileSize := extentStart + opts.VolumeSize + + fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0644) + if err != nil { + return nil, fmt.Errorf("blockvol: create file: %w", err) + } + + if err := fd.Truncate(int64(totalFileSize)); err != nil { + fd.Close() + os.Remove(path) + return nil, fmt.Errorf("blockvol: truncate: %w", err) + } + + if _, err := sb.WriteTo(fd); err != nil { + fd.Close() + os.Remove(path) + return nil, fmt.Errorf("blockvol: write superblock: %w", err) + } + + if err := fd.Sync(); err != nil { + fd.Close() + os.Remove(path) + return nil, fmt.Errorf("blockvol: sync: %w", err) + } + + dm := NewDirtyMap() + wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, 0, 0) + v := &BlockVol{ + fd: fd, + path: path, + super: sb, + wal: wal, + dirtyMap: dm, + } + v.nextLSN.Store(1) + v.healthy.Store(true) + v.groupCommit = NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: fd.Sync, + OnDegraded: func() { v.healthy.Store(false) }, + }) + go v.groupCommit.Run() + v.flusher = NewFlusher(FlusherConfig{ + FD: fd, + Super: &v.super, + WAL: wal, + DirtyMap: dm, + }) + go v.flusher.Run() + return v, nil +} + +// OpenBlockVol opens an existing block volume file and runs crash recovery. +func OpenBlockVol(path string) (*BlockVol, error) { + fd, err := os.OpenFile(path, os.O_RDWR, 0644) + if err != nil { + return nil, fmt.Errorf("blockvol: open file: %w", err) + } + + sb, err := ReadSuperblock(fd) + if err != nil { + fd.Close() + return nil, fmt.Errorf("blockvol: read superblock: %w", err) + } + if err := sb.Validate(); err != nil { + fd.Close() + return nil, fmt.Errorf("blockvol: validate superblock: %w", err) + } + + dirtyMap := NewDirtyMap() + + // Run WAL recovery: replay entries from tail to head. 
+ result, err := RecoverWAL(fd, &sb, dirtyMap) + if err != nil { + fd.Close() + return nil, fmt.Errorf("blockvol: recovery: %w", err) + } + + nextLSN := sb.WALCheckpointLSN + 1 + if result.HighestLSN >= nextLSN { + nextLSN = result.HighestLSN + 1 + } + + wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, sb.WALHead, sb.WALTail) + v := &BlockVol{ + fd: fd, + path: path, + super: sb, + wal: wal, + dirtyMap: dirtyMap, + } + v.nextLSN.Store(nextLSN) + v.healthy.Store(true) + v.groupCommit = NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: fd.Sync, + OnDegraded: func() { v.healthy.Store(false) }, + }) + go v.groupCommit.Run() + v.flusher = NewFlusher(FlusherConfig{ + FD: fd, + Super: &v.super, + WAL: wal, + DirtyMap: dirtyMap, + }) + go v.flusher.Run() + return v, nil +} + +// WriteLBA writes data at the given logical block address. +// Data length must be a multiple of BlockSize. +func (v *BlockVol) WriteLBA(lba uint64, data []byte) error { + if err := ValidateWrite(lba, uint32(len(data)), v.super.VolumeSize, v.super.BlockSize); err != nil { + return err + } + + lsn := v.nextLSN.Add(1) - 1 + entry := &WALEntry{ + LSN: lsn, + Epoch: 0, // Phase 1: no fencing + Type: EntryTypeWrite, + LBA: lba, + Length: uint32(len(data)), + Data: data, + } + + walOff, err := v.wal.Append(entry) + if err != nil { + return fmt.Errorf("blockvol: WriteLBA: %w", err) + } + + // Update dirty map: one entry per block written. + blocks := uint32(len(data)) / v.super.BlockSize + for i := uint32(0); i < blocks; i++ { + blockOff := walOff // all blocks in this entry share the same WAL offset + v.dirtyMap.Put(lba+uint64(i), blockOff, lsn, v.super.BlockSize) + } + + return nil +} + +// ReadLBA reads data at the given logical block address. +// length is in bytes and must be a multiple of BlockSize. 
+func (v *BlockVol) ReadLBA(lba uint64, length uint32) ([]byte, error) { + if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { + return nil, err + } + + blocks := length / v.super.BlockSize + result := make([]byte, length) + + for i := uint32(0); i < blocks; i++ { + blockLBA := lba + uint64(i) + blockData, err := v.readOneBlock(blockLBA) + if err != nil { + return nil, fmt.Errorf("blockvol: ReadLBA block %d: %w", blockLBA, err) + } + copy(result[i*v.super.BlockSize:], blockData) + } + + return result, nil +} + +// readOneBlock reads a single block, checking dirty map first, then extent. +func (v *BlockVol) readOneBlock(lba uint64) ([]byte, error) { + walOff, _, _, ok := v.dirtyMap.Get(lba) + if ok { + return v.readBlockFromWAL(walOff, lba) + } + return v.readBlockFromExtent(lba) +} + +// maxWALEntryDataLen caps the data length we trust from a WAL entry header. +// Anything larger than the WAL region itself is corrupt. +const maxWALEntryDataLen = 256 * 1024 * 1024 // 256MB absolute ceiling + +// readBlockFromWAL reads a block's data from its WAL entry. +func (v *BlockVol) readBlockFromWAL(walOff uint64, lba uint64) ([]byte, error) { + // Read the WAL entry header to get the full entry size. + headerBuf := make([]byte, walEntryHeaderSize) + absOff := int64(v.super.WALOffset + walOff) + if _, err := v.fd.ReadAt(headerBuf, absOff); err != nil { + return nil, fmt.Errorf("readBlockFromWAL: read header at %d: %w", absOff, err) + } + + // Check entry type first — TRIM has no data payload, so Length is + // metadata (trim extent), not a data size to allocate. + entryType := headerBuf[16] // Type is at offset LSN(8) + Epoch(8) = 16 + if entryType == EntryTypeTrim { + // TRIM entry: return zeros regardless of Length field. 
+ return make([]byte, v.super.BlockSize), nil + } + if entryType != EntryTypeWrite { + return nil, fmt.Errorf("readBlockFromWAL: expected WRITE or TRIM entry, got type 0x%02x", entryType) + } + + // Parse and validate the data Length field before allocating (WRITE only). + dataLen := v.parseDataLength(headerBuf) + if uint64(dataLen) > v.super.WALSize || uint64(dataLen) > maxWALEntryDataLen { + return nil, fmt.Errorf("readBlockFromWAL: corrupt entry length %d exceeds WAL size %d", dataLen, v.super.WALSize) + } + + entryLen := walEntryHeaderSize + int(dataLen) + fullBuf := make([]byte, entryLen) + if _, err := v.fd.ReadAt(fullBuf, absOff); err != nil { + return nil, fmt.Errorf("readBlockFromWAL: read entry at %d: %w", absOff, err) + } + + entry, err := DecodeWALEntry(fullBuf) + if err != nil { + return nil, fmt.Errorf("readBlockFromWAL: decode: %w", err) + } + + // Find the block within the entry's data. + blockOffset := (lba - entry.LBA) * uint64(v.super.BlockSize) + if blockOffset+uint64(v.super.BlockSize) > uint64(len(entry.Data)) { + return nil, fmt.Errorf("readBlockFromWAL: block offset %d out of range for entry data len %d", blockOffset, len(entry.Data)) + } + + block := make([]byte, v.super.BlockSize) + copy(block, entry.Data[blockOffset:blockOffset+uint64(v.super.BlockSize)]) + return block, nil +} + +// parseDataLength extracts the Length field from a WAL entry header buffer. +func (v *BlockVol) parseDataLength(headerBuf []byte) uint32 { + // Length is at offset: LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8) = 26 + return binary.LittleEndian.Uint32(headerBuf[26:]) +} + +// readBlockFromExtent reads a block directly from the extent region. 
+func (v *BlockVol) readBlockFromExtent(lba uint64) ([]byte, error) { + extentStart := v.super.WALOffset + v.super.WALSize + byteOffset := extentStart + lba*uint64(v.super.BlockSize) + + block := make([]byte, v.super.BlockSize) + if _, err := v.fd.ReadAt(block, int64(byteOffset)); err != nil { + return nil, fmt.Errorf("readBlockFromExtent: pread at %d: %w", byteOffset, err) + } + return block, nil +} + +// Trim marks blocks as deallocated. Subsequent reads return zeros. +// The trim is recorded in the WAL with a Length field so the flusher +// can zero the extent region and recovery can replay the trim. +func (v *BlockVol) Trim(lba uint64, length uint32) error { + if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { + return err + } + + lsn := v.nextLSN.Add(1) - 1 + entry := &WALEntry{ + LSN: lsn, + Epoch: 0, + Type: EntryTypeTrim, + LBA: lba, + Length: length, + } + + walOff, err := v.wal.Append(entry) + if err != nil { + return fmt.Errorf("blockvol: Trim: %w", err) + } + + // Update dirty map: mark each trimmed block so the flusher sees it. + // readOneBlock checks entry type and returns zeros for TRIM entries. + blocks := length / v.super.BlockSize + for i := uint32(0); i < blocks; i++ { + v.dirtyMap.Put(lba+uint64(i), walOff, lsn, v.super.BlockSize) + } + + return nil +} + +// Path returns the file path of the block volume. +func (v *BlockVol) Path() string { + return v.path +} + +// Info returns volume metadata. +func (v *BlockVol) Info() VolumeInfo { + return VolumeInfo{ + VolumeSize: v.super.VolumeSize, + BlockSize: v.super.BlockSize, + ExtentSize: v.super.ExtentSize, + WALSize: v.super.WALSize, + Healthy: v.healthy.Load(), + } +} + +// VolumeInfo contains read-only volume metadata. +type VolumeInfo struct { + VolumeSize uint64 + BlockSize uint32 + ExtentSize uint32 + WALSize uint64 + Healthy bool +} + +// SyncCache ensures all previously written WAL entries are durable on disk. 
+// It submits a sync request to the group committer, which batches fsyncs. +func (v *BlockVol) SyncCache() error { + return v.groupCommit.Submit() +} + +// Close shuts down the block volume and closes the file. +// Shutdown order: group committer → stop flusher goroutine → final flush → close fd. +func (v *BlockVol) Close() error { + if v.groupCommit != nil { + v.groupCommit.Stop() + } + var flushErr error + if v.flusher != nil { + v.flusher.Stop() // stop background goroutine first (no concurrent flush) + flushErr = v.flusher.FlushOnce() // then do final flush safely + } + closeErr := v.fd.Close() + if flushErr != nil { + return flushErr + } + return closeErr +} diff --git a/weed/storage/blockvol/blockvol_qa_test.go b/weed/storage/blockvol/blockvol_qa_test.go new file mode 100644 index 000000000..b5cdefd7e --- /dev/null +++ b/weed/storage/blockvol/blockvol_qa_test.go @@ -0,0 +1,3903 @@ +package blockvol + +// QA adversarial tests — written by QA Manager (separate from dev team's unit tests). +// Attack vectors: boundary conditions, multi-block I/O, trim semantics, +// concurrency, oracle pattern, corruption injection, lifecycle edge cases. 
+ +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math/rand" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestQA(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "qa_multi_block_write_read_middle", run: testQAMultiBlockWriteReadMiddle}, + {name: "qa_trim_then_read_zeros", run: testQATrimThenReadZeros}, + {name: "qa_trim_dirty_then_read_zeros", run: testQATrimDirtyThenReadZeros}, + {name: "qa_write_last_lba", run: testQAWriteLastLBA}, + {name: "qa_overwrite_wider", run: testQAOverwriteWider}, + {name: "qa_overwrite_narrower", run: testQAOverwriteNarrower}, + {name: "qa_read_never_written", run: testQAReadNeverWritten}, + {name: "qa_concurrent_writes", run: testQAConcurrentWrites}, + {name: "qa_concurrent_write_read", run: testQAConcurrentWriteRead}, + {name: "qa_wal_fill_advance_refill", run: testQAWALFillAdvanceRefill}, + {name: "qa_create_block_size_512", run: testQACreateBlockSize512}, + {name: "qa_create_block_size_8192", run: testQACreateBlockSize8192}, + {name: "qa_validate_write_zero_length", run: testQAValidateWriteZeroLength}, + {name: "qa_double_close", run: testQADoubleClose}, + {name: "qa_write_read_all_lbas", run: testQAWriteReadAllLBAs}, + {name: "qa_dirty_map_range_during_delete", run: testQADirtyMapRangeDuringDelete}, + {name: "qa_wal_entry_bitflip_systematic", run: testQAWALEntryBitflipSystematic}, + {name: "qa_oracle_random_ops", run: testQAOracleRandomOps}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +// --- Multi-block I/O --- + +// testQAMultiBlockWriteReadMiddle: Write 3 blocks as one WriteLBA call, +// then read only the 2nd block. Exercises blockOffset calculation in readBlockFromWAL. 
+func testQAMultiBlockWriteReadMiddle(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // 3 blocks: 'A' 'B' 'C' + data := make([]byte, 3*4096) + for i := 0; i < 4096; i++ { + data[i] = 'A' + data[4096+i] = 'B' + data[2*4096+i] = 'C' + } + + // Write 3 blocks starting at LBA 5 + if err := v.WriteLBA(5, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Read only block at LBA 6 (the middle one — should be 'B') + got, err := v.ReadLBA(6, 4096) + if err != nil { + t.Fatalf("ReadLBA(6): %v", err) + } + if !bytes.Equal(got, makeBlock('B')) { + t.Errorf("middle block: got %q..., want all 'B'", got[:8]) + } + + // Read only block at LBA 7 (last — should be 'C') + got, err = v.ReadLBA(7, 4096) + if err != nil { + t.Fatalf("ReadLBA(7): %v", err) + } + if !bytes.Equal(got, makeBlock('C')) { + t.Errorf("last block: got %q..., want all 'C'", got[:8]) + } + + // Read only first block at LBA 5 (should be 'A') + got, err = v.ReadLBA(5, 4096) + if err != nil { + t.Fatalf("ReadLBA(5): %v", err) + } + if !bytes.Equal(got, makeBlock('A')) { + t.Errorf("first block: got %q..., want all 'A'", got[:8]) + } +} + +// --- Trim semantics --- + +// testQATrimThenReadZeros: Write a block, trim it, read back — must get zeros. +func testQATrimThenReadZeros(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + data := makeBlock('Z') + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Verify data is there. + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA before trim: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatal("data not written correctly before trim") + } + + // Trim the block. + if err := v.Trim(0, 4096); err != nil { + t.Fatalf("Trim: %v", err) + } + + // Read after trim — must be zeros (from extent, which was never written). 
+ got, err = v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after trim: %v", err) + } + zeros := make([]byte, 4096) + if !bytes.Equal(got, zeros) { + t.Errorf("after trim: expected zeros, got non-zero data (first byte = 0x%02x)", got[0]) + } +} + +// testQATrimDirtyThenReadZeros: Write two blocks, trim only one, verify only the +// trimmed one returns zeros and the other is intact. +func testQATrimDirtyThenReadZeros(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + if err := v.WriteLBA(0, makeBlock('X')); err != nil { + t.Fatalf("WriteLBA(0): %v", err) + } + if err := v.WriteLBA(1, makeBlock('Y')); err != nil { + t.Fatalf("WriteLBA(1): %v", err) + } + + // Trim only LBA 0. + if err := v.Trim(0, 4096); err != nil { + t.Fatalf("Trim(0): %v", err) + } + + // LBA 0 should be zeros. + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0) after trim: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("LBA 0 should be zeros after trim") + } + + // LBA 1 should still be 'Y'. + got, err = v.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1): %v", err) + } + if !bytes.Equal(got, makeBlock('Y')) { + t.Error("LBA 1 should still be 'Y' after trimming LBA 0") + } +} + +// --- Boundary conditions --- + +// testQAWriteLastLBA: Write to the very last block of the volume. +func testQAWriteLastLBA(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // Volume is 1MB with 4KB blocks -> 256 blocks -> last LBA is 255. + lastLBA := v.super.VolumeSize/uint64(v.super.BlockSize) - 1 + + data := makeBlock('L') + if err := v.WriteLBA(lastLBA, data); err != nil { + t.Fatalf("WriteLBA(last=%d): %v", lastLBA, err) + } + + got, err := v.ReadLBA(lastLBA, 4096) + if err != nil { + t.Fatalf("ReadLBA(last=%d): %v", lastLBA, err) + } + if !bytes.Equal(got, data) { + t.Error("last LBA data mismatch") + } + + // One past last should fail. 
+ if err := v.WriteLBA(lastLBA+1, data); err == nil { + t.Error("expected error writing past last LBA") + } +} + +// --- Overwrite with different sizes --- + +// testQAOverwriteWider: Write 1 block at LBA 0, then 2 blocks at LBA 0. +// Both blocks in dirty map should reflect the 2-block write. +func testQAOverwriteWider(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // Write 1 block of 'A' at LBA 0. + if err := v.WriteLBA(0, makeBlock('A')); err != nil { + t.Fatalf("WriteLBA 1-block: %v", err) + } + + // Overwrite with 2 blocks at LBA 0: 'X' and 'Y'. + wideData := make([]byte, 2*4096) + for i := 0; i < 4096; i++ { + wideData[i] = 'X' + wideData[4096+i] = 'Y' + } + if err := v.WriteLBA(0, wideData); err != nil { + t.Fatalf("WriteLBA 2-block: %v", err) + } + + // Read LBA 0 — should be 'X' (not 'A'). + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0): %v", err) + } + if !bytes.Equal(got, makeBlock('X')) { + t.Error("LBA 0 should be 'X' after wider overwrite") + } + + // Read LBA 1 — should be 'Y'. + got, err = v.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1): %v", err) + } + if !bytes.Equal(got, makeBlock('Y')) { + t.Error("LBA 1 should be 'Y' after wider overwrite") + } +} + +// testQAOverwriteNarrower: Write 2 blocks at LBA 0, then 1 block at LBA 0. +// LBA 0 gets new data, LBA 1 retains old data. +func testQAOverwriteNarrower(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // Write 2 blocks: 'A' at LBA 0, 'B' at LBA 1. + wideData := make([]byte, 2*4096) + for i := 0; i < 4096; i++ { + wideData[i] = 'A' + wideData[4096+i] = 'B' + } + if err := v.WriteLBA(0, wideData); err != nil { + t.Fatalf("WriteLBA 2-block: %v", err) + } + + // Overwrite only LBA 0 with 'Z'. + if err := v.WriteLBA(0, makeBlock('Z')); err != nil { + t.Fatalf("WriteLBA 1-block: %v", err) + } + + // LBA 0 should be 'Z'. 
+ got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0): %v", err) + } + if !bytes.Equal(got, makeBlock('Z')) { + t.Error("LBA 0 should be 'Z' after narrower overwrite") + } + + // LBA 1 should still be 'B' from the original 2-block write. + got, err = v.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1): %v", err) + } + if !bytes.Equal(got, makeBlock('B')) { + t.Error("LBA 1 should still be 'B' — narrower overwrite shouldn't touch it") + } +} + +// --- Never-written blocks --- + +// testQAReadNeverWritten: Read a block that was never written (isolated test). +func testQAReadNeverWritten(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + got, err := v.ReadLBA(42, 4096) + if err != nil { + t.Fatalf("ReadLBA(42) never-written: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("never-written block should be all zeros") + } +} + +// --- Concurrency --- + +// testQAConcurrentWrites: Hammer WriteLBA from 16 goroutines. +// Verify: no panics, all reads return valid data, LSNs are unique. +func testQAConcurrentWrites(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "concurrent.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 4 * 1024 * 1024, // 4MB (1024 LBAs) + BlockSize: 4096, + WALSize: 2 * 1024 * 1024, // 2MB WAL (plenty of room) + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + defer v.Close() + + const goroutines = 16 + const opsPerGoroutine = 50 + + var wg sync.WaitGroup + errs := make(chan error, goroutines*opsPerGoroutine) + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + // Each goroutine writes to its own LBA range to avoid dirty map + // read-back races (we test write correctness, not read-write ordering). 
+ baseLBA := uint64(id * opsPerGoroutine) + for i := 0; i < opsPerGoroutine; i++ { + lba := baseLBA + uint64(i) + if lba >= 1024 { + continue // stay within volume + } + data := makeBlock(byte('A' + id%26)) + if err := v.WriteLBA(lba, data); err != nil { + if errors.Is(err, ErrWALFull) { + return // WAL full is expected, not a bug + } + errs <- err + return + } + } + }(g) + } + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent write error: %v", err) + } + + // Spot-check a few reads. Some may not have been written (WAL full). + for g := 0; g < goroutines; g++ { + lba := uint64(g * opsPerGoroutine) + if lba >= 1024 { + continue + } + got, err := v.ReadLBA(lba, 4096) + if err != nil { + t.Errorf("ReadLBA(%d) after concurrent writes: %v", lba, err) + continue + } + expected := makeBlock(byte('A' + g%26)) + zeros := make([]byte, 4096) + if !bytes.Equal(got, expected) && !bytes.Equal(got, zeros) { + t.Errorf("LBA %d: data mismatch after concurrent write (not expected data or zeros)", lba) + } + } +} + +// testQAConcurrentWriteRead: One writer and multiple readers on same LBA. +// Readers should always see either old data or new data, never garbage. +func testQAConcurrentWriteRead(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // Seed with initial data. + if err := v.WriteLBA(0, makeBlock('A')); err != nil { + t.Fatalf("seed write: %v", err) + } + + var wg sync.WaitGroup + stop := make(chan struct{}) + + // Writer: overwrites LBA 0 with 'B'. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + v.WriteLBA(0, makeBlock('B')) + } + close(stop) + }() + + // Readers: read LBA 0 and verify data is coherent (all same byte). + for r := 0; r < 4; r++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + } + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Errorf("concurrent read: %v", err) + return + } + // Every byte in the block should be the same value. 
+ first := got[0] + for j, b := range got { + if b != first { + t.Errorf("torn read at byte %d: got 0x%02x, expected 0x%02x", j, b, first) + return + } + } + } + }() + } + + wg.Wait() +} + +// --- WAL capacity management --- + +// testQAWALFillAdvanceRefill: Fill WAL, advance tail, write more entries. +func testQAWALFillAdvanceRefill(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "wal_refill.blockvol") + // Small WAL: 128KB. + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 128 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + defer v.Close() + + entrySize := uint64(walEntryHeaderSize + 4096) // ~4134 bytes per entry + maxEntries := 128 * 1024 / int(entrySize) // ~31 entries + + // Write until WAL is full. + var lastOK int + for i := 0; i < maxEntries+5; i++ { + err := v.WriteLBA(uint64(i%256), makeBlock(byte('A'+i%26))) + if err != nil { + break + } + lastOK = i + } + if lastOK == 0 { + t.Fatal("couldn't write any entries") + } + + // Advance tail to free half the WAL. + halfEntries := uint64(lastOK/2+1) * entrySize + v.wal.AdvanceTail(halfEntries) + + // Write more — should succeed now. + for i := 0; i < 5; i++ { + if err := v.WriteLBA(uint64(i), makeBlock(byte('a'+i))); err != nil { + t.Fatalf("write after tail advance %d: %v", i, err) + } + } + + // Verify latest writes are readable. 
+ for i := 0; i < 5; i++ { + got, err := v.ReadLBA(uint64(i), 4096) + if err != nil { + t.Fatalf("ReadLBA(%d) after refill: %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('a'+i))) { + t.Errorf("LBA %d: data mismatch after refill", i) + } + } +} + +// --- Non-default block sizes --- + +func testQACreateBlockSize512(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bs512.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 512, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol(512): %v", err) + } + defer v.Close() + + data := make([]byte, 512) + for i := range data { + data[i] = 0xAB + } + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA(bs=512): %v", err) + } + got, err := v.ReadLBA(0, 512) + if err != nil { + t.Fatalf("ReadLBA(bs=512): %v", err) + } + if !bytes.Equal(got, data) { + t.Error("512-byte block: data mismatch") + } +} + +func testQACreateBlockSize8192(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bs8192.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 8192, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol(8192): %v", err) + } + defer v.Close() + + data := make([]byte, 8192) + for i := range data { + data[i] = 0xCD + } + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA(bs=8192): %v", err) + } + got, err := v.ReadLBA(0, 8192) + if err != nil { + t.Fatalf("ReadLBA(bs=8192): %v", err) + } + if !bytes.Equal(got, data) { + t.Error("8192-byte block: data mismatch") + } +} + +// --- Validation edge cases --- + +func testQAValidateWriteZeroLength(t *testing.T) { + err := ValidateWrite(0, 0, 1024*1024, 4096) + // 0-length write: dataLen%blockSize == 0 (0%4096 == 0), but blocksNeeded == 0. + // This should arguably be rejected, but current code may allow it. + // If it's allowed, at least it shouldn't crash. 
+ if err != nil { + // Good — zero-length writes rejected. + return + } + // If allowed, WriteLBA with empty data should be caught by WAL entry validation. + v := createTestVol(t) + defer v.Close() + err = v.WriteLBA(0, []byte{}) + if err == nil { + t.Error("WriteLBA with empty data should be rejected") + } +} + +// --- Lifecycle --- + +func testQADoubleClose(t *testing.T) { + v := createTestVol(t) + if err := v.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + // Second close should not panic (may return error, that's fine). + _ = v.Close() +} + +// --- Exhaustive small-volume test --- + +// testQAWriteReadAllLBAs: Write unique data to every LBA, read all back. +func testQAWriteReadAllLBAs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "alllba.blockvol") + volSize := uint64(32 * 4096) // 32 blocks + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: volSize, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + defer v.Close() + + totalBlocks := volSize / 4096 + + // Write every block with unique pattern. + for lba := uint64(0); lba < totalBlocks; lba++ { + data := makeBlock(byte(lba)) + if err := v.WriteLBA(lba, data); err != nil { + t.Fatalf("WriteLBA(%d): %v", lba, err) + } + } + + // Read every block and verify. + for lba := uint64(0); lba < totalBlocks; lba++ { + got, err := v.ReadLBA(lba, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", lba, err) + } + expected := makeBlock(byte(lba)) + if !bytes.Equal(got, expected) { + t.Errorf("LBA %d: data mismatch", lba) + } + } +} + +// --- DirtyMap adversarial --- + +// testQADirtyMapRangeDuringDelete: Verify Range + Delete doesn't deadlock. +// This tests the reviewer fix (snapshot-then-iterate pattern). +func testQADirtyMapRangeDuringDelete(t *testing.T) { + dm := NewDirtyMap() + + // Populate 100 entries. 
+ for i := uint64(0); i < 100; i++ { + dm.Put(i, i*100, i, 4096) + } + + // Range over all, delete each one inside the callback. + dm.Range(0, 100, func(lba, walOffset, lsn uint64, length uint32) { + dm.Delete(lba) + }) + + // All entries should be deleted. + if dm.Len() != 0 { + t.Errorf("expected 0 entries after Range+Delete, got %d", dm.Len()) + } +} + +// --- WAL entry corruption --- + +// testQAWALEntryBitflipSystematic: Encode a valid entry, flip one bit at +// each byte position, verify Decode detects the corruption. +func testQAWALEntryBitflipSystematic(t *testing.T) { + entry := &WALEntry{ + LSN: 42, + Epoch: 7, + Type: EntryTypeWrite, + LBA: 100, + Length: 64, + Data: bytes.Repeat([]byte("DEADBEEF"), 8), + } + + original, err := entry.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Verify original decodes fine. + if _, err := DecodeWALEntry(original); err != nil { + t.Fatalf("original decode failed: %v", err) + } + + corrupted := 0 + detected := 0 + + for bytePos := 0; bytePos < len(original); bytePos++ { + for bit := 0; bit < 8; bit++ { + flipped := make([]byte, len(original)) + copy(flipped, original) + flipped[bytePos] ^= 1 << uint(bit) + + corrupted++ + _, err := DecodeWALEntry(flipped) + if err != nil { + detected++ + } + } + } + + // CRC32 should catch the vast majority. With 38+64=102 bytes and + // single-bit flips, CRC32 IEEE guarantees detection for bursts up to 32 bits. + detectionRate := float64(detected) / float64(corrupted) * 100 + t.Logf("bitflip detection: %d/%d (%.1f%%)", detected, corrupted, detectionRate) + + // We expect near-100% detection. Allow for the CRC and EntrySize bytes + // themselves which may produce self-consistent mutations. 
+ if detectionRate < 95.0 { + t.Errorf("detection rate %.1f%% is too low (expected >= 95%%)", detectionRate) + } +} + +// --- Oracle pattern (the crown jewel of adversarial testing) --- + +// testQAOracleRandomOps: Execute random write/read/trim operations against +// both BlockVol and an in-memory oracle. Assert they always agree. +func testQAOracleRandomOps(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "oracle.blockvol") + const blockSize = 4096 + const numBlocks = 64 // small volume for fast test + const volSize = numBlocks * blockSize + + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: volSize, + BlockSize: blockSize, + WALSize: 512 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + defer v.Close() + + // Oracle: simple map from LBA to block data. + // Missing entries mean zeros. + oracle := make(map[uint64][]byte) + + rng := rand.New(rand.NewSource(0xDEADBEEF)) + + const numOps = 500 + + for i := 0; i < numOps; i++ { + op := rng.Intn(3) // 0=write, 1=read, 2=trim + lba := uint64(rng.Intn(numBlocks)) + + switch op { + case 0: // WRITE + // Write 1-4 blocks (if they fit). + maxBlocks := numBlocks - int(lba) + if maxBlocks <= 0 { + continue + } + nBlocks := rng.Intn(min(4, maxBlocks)) + 1 + data := make([]byte, nBlocks*blockSize) + rng.Read(data) + + err := v.WriteLBA(lba, data) + if err != nil { + if errors.Is(err, ErrWALFull) { + continue // WAL full is expected without flusher + } + t.Fatalf("op %d: WriteLBA(%d, %d blocks): %v", i, lba, nBlocks, err) + } + + // Update oracle. + for b := 0; b < nBlocks; b++ { + blockData := make([]byte, blockSize) + copy(blockData, data[b*blockSize:(b+1)*blockSize]) + oracle[lba+uint64(b)] = blockData + } + + case 1: // READ + got, err := v.ReadLBA(lba, blockSize) + if err != nil { + t.Fatalf("op %d: ReadLBA(%d): %v", i, lba, err) + } + + // Oracle answer. 
+ expected, ok := oracle[lba] + if !ok { + expected = make([]byte, blockSize) // zeros + } + if !bytes.Equal(got, expected) { + t.Fatalf("op %d: ReadLBA(%d) oracle mismatch at op %d", i, lba, i) + } + + case 2: // TRIM + err := v.Trim(lba, blockSize) + if err != nil { + if errors.Is(err, ErrWALFull) { + continue // WAL full is expected without flusher + } + t.Fatalf("op %d: Trim(%d): %v", i, lba, err) + } + delete(oracle, lba) + } + } + + // Final verification: read every block and compare to oracle. + for lba := uint64(0); lba < numBlocks; lba++ { + got, err := v.ReadLBA(lba, blockSize) + if err != nil { + t.Fatalf("final ReadLBA(%d): %v", lba, err) + } + expected, ok := oracle[lba] + if !ok { + expected = make([]byte, blockSize) + } + if !bytes.Equal(got, expected) { + t.Errorf("final LBA %d: oracle mismatch", lba) + } + } + + t.Logf("oracle test: %d ops, %d blocks, all consistent", numOps, numBlocks) +} + +// --- Superblock validation adversarial --- + +func TestQASuperblockValidation(t *testing.T) { + tests := []struct { + name string + mutate func(sb *Superblock) + wantErr error + }{ + { + name: "extent_size_zero", + mutate: func(sb *Superblock) { sb.ExtentSize = 0 }, + wantErr: ErrInvalidSuperblock, + }, + { + name: "wal_size_zero", + mutate: func(sb *Superblock) { sb.WALSize = 0 }, + wantErr: ErrInvalidSuperblock, + }, + { + name: "wal_offset_wrong", + mutate: func(sb *Superblock) { sb.WALOffset = 999 }, + wantErr: ErrInvalidSuperblock, + }, + { + name: "volume_not_aligned", + mutate: func(sb *Superblock) { sb.VolumeSize = 4097 }, + wantErr: ErrInvalidSuperblock, + }, + { + name: "bad_magic", + mutate: func(sb *Superblock) { copy(sb.Magic[:], "BAAD") }, + wantErr: ErrNotBlockVol, + }, + { + name: "bad_version", + mutate: func(sb *Superblock) { sb.Version = 99 }, + wantErr: ErrUnsupportedVersion, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sb, err := NewSuperblock(1024*1024, CreateOptions{}) + if err != nil { + 
t.Fatalf("NewSuperblock: %v", err) + } + tt.mutate(&sb) + err = sb.Validate() + if err == nil { + t.Fatal("expected Validate() error, got nil") + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected %v, got %v", tt.wantErr, err) + } + }) + } +} + +// --- WAL writer adversarial --- + +func TestQAWALWriterEdgeCases(t *testing.T) { + t.Run("entry_larger_than_wal", func(t *testing.T) { + walOffset := uint64(SuperblockSize) + walSize := uint64(1024) // tiny WAL + fd, cleanup := createTestWAL(t, walOffset, walSize) + defer cleanup() + + w := NewWALWriter(fd, walOffset, walSize, 0, 0) + + // Entry with 4KB data > 1KB WAL. + entry := &WALEntry{LSN: 1, Type: EntryTypeWrite, LBA: 0, Length: 4096, Data: make([]byte, 4096)} + _, err := w.Append(entry) + if err == nil { + t.Error("expected error when entry exceeds WAL size") + } + }) + + t.Run("padding_smaller_than_header", func(t *testing.T) { + walOffset := uint64(SuperblockSize) + entrySize := uint64(walEntryHeaderSize + 64) + // WAL that leaves less than walEntryHeaderSize bytes after one entry. + // After padding + wrap, logical tracking allows exact fit (no 1-byte reservation). + // padding gap = walSize - entrySize (< walEntryHeaderSize). + // After wrap: logicalHead consumed entrySize + gap, logicalTail advanced to entrySize. + // used = (entrySize + gap) - entrySize = gap. free = walSize - gap. + // If free >= entrySize, second entry fits. + walSize := entrySize + uint64(walEntryHeaderSize) - 5 + // gap = walSize - entrySize = walEntryHeaderSize - 5 = 33 bytes + // free after wrap = walSize - gap = entrySize = 102 bytes + // entrySize = 102. 
free == entrySize → fits with logical tracking (uses >, not <) + + fd, cleanup := createTestWAL(t, walOffset, walSize) + defer cleanup() + + w := NewWALWriter(fd, walOffset, walSize, 0, 0) + + entry1 := &WALEntry{LSN: 1, Type: EntryTypeWrite, LBA: 0, Length: 64, Data: make([]byte, 64)} + if _, err := w.Append(entry1); err != nil { + t.Fatalf("first append: %v", err) + } + + // Advance tail past first entry so wrap has space. + w.AdvanceTail(entrySize) + + // Second entry wraps. With logical counters (no 1-byte reservation), + // the entry fits exactly when available == needed. + entry2 := &WALEntry{LSN: 2, Type: EntryTypeWrite, LBA: 1, Length: 64, Data: make([]byte, 64)} + if _, err := w.Append(entry2); err != nil { + t.Fatalf("second append after wrap should succeed with logical tracking: %v", err) + } + }) + + t.Run("padding_smaller_than_header_with_room", func(t *testing.T) { + walOffset := uint64(SuperblockSize) + entrySize := uint64(walEntryHeaderSize + 64) // 102 + // Need: after first entry + padding gap + wrap, free > entrySize. + // padding gap = walSize - entrySize (what's left at end, < header size). + // After wrap: head=0, tail=entrySize. free = tail - head = entrySize. + // strict < needs free > entryLen, so entrySize > entrySize is false. + // Need extra room: walSize = entrySize + gap + extra. + // With gap < walEntryHeaderSize (say 30) and extra >= 2: + walSize := entrySize + 30 + entrySize + 2 // room for 2 entries + 30-byte gap + 2-byte margin + + fd, cleanup := createTestWAL(t, walOffset, walSize) + defer cleanup() + + w := NewWALWriter(fd, walOffset, walSize, 0, 0) + + entry1 := &WALEntry{LSN: 1, Type: EntryTypeWrite, LBA: 0, Length: 64, Data: make([]byte, 64)} + if _, err := w.Append(entry1); err != nil { + t.Fatalf("first append: %v", err) + } + + // Write second entry — pushes head to 2*entrySize = 204. 
+ entry2 := &WALEntry{LSN: 2, Type: EntryTypeWrite, LBA: 1, Length: 64, Data: make([]byte, 64)} + if _, err := w.Append(entry2); err != nil { + t.Fatalf("second append: %v", err) + } + + // Advance tail past both entries. + w.AdvanceTail(entrySize * 2) + + // remaining = walSize - 204 = 30 bytes (< walEntryHeaderSize=38) + // → padding uses zero-fill path, head wraps to 0 + // → free = tail - head = 204 - 0 = 204 > 102 → fits! + entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 64, Data: make([]byte, 64)} + off, err := w.Append(entry3) + if err != nil { + t.Fatalf("wrap append with room: %v", err) + } + if off != 0 { + t.Errorf("wrapped entry should be at offset 0, got %d", off) + } + }) +} + +// --- WAL entry edge cases --- + +func TestQAWALEntryEdgeCases(t *testing.T) { + t.Run("decode_truncated_data", func(t *testing.T) { + entry := &WALEntry{LSN: 1, Type: EntryTypeWrite, LBA: 0, Length: 4096, Data: make([]byte, 4096)} + buf, err := entry.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + // Truncate buffer to header + partial data. + truncated := buf[:walEntryHeaderSize+100] + _, err = DecodeWALEntry(truncated) + if err == nil { + t.Error("expected error decoding truncated entry") + } + }) + + t.Run("decode_header_only", func(t *testing.T) { + _, err := DecodeWALEntry(make([]byte, walEntryHeaderSize-1)) + if err == nil { + t.Error("expected error for buffer smaller than header") + } + }) + + t.Run("corrupt_entry_size_field", func(t *testing.T) { + entry := &WALEntry{LSN: 1, Type: EntryTypeWrite, LBA: 0, Length: 64, Data: make([]byte, 64)} + buf, err := entry.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + // Corrupt the EntrySize field (last 4 bytes). 
		// Corrupt the trailing EntrySize field with an absurd value;
		// DecodeWALEntry must reject the entry rather than trust the
		// bogus length.
		binary.LittleEndian.PutUint32(buf[len(buf)-4:], 99999)
		_, err = DecodeWALEntry(buf)
		if err == nil {
			t.Error("expected error for corrupt EntrySize")
		}
	})

	t.Run("unknown_entry_type_encode", func(t *testing.T) {
		// Type 0x99 is not recognized — Encode should still work
		// (only WRITE/TRIM/BARRIER have special validation).
		entry := &WALEntry{LSN: 1, Type: 0x99, LBA: 0}
		_, err := entry.Encode()
		// Unknown type with no data — may or may not error.
		// Just verify no panic.
		_ = err
	})
}

// min returns the smaller of two ints. Go 1.21+ has a built-in min; this
// package-level helper shadows it so the file also builds on older
// toolchains.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// ============================================================================
// QA Adversarial Tests — Tasks 1.7 (GroupCommitter), 1.8 (Flusher), 1.9 (Recovery)
// ============================================================================

// --- Task 1.7: GroupCommitter adversarial tests ---

// TestQAGroupCommitter is the table driver for the group-commit adversarial
// subtests; each entry runs as its own named subtest.
func TestQAGroupCommitter(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "qa_gc_double_stop", run: testQAGCDoubleStop},
		{name: "qa_gc_submit_storm_during_stop", run: testQAGCSubmitStormDuringStop},
		{name: "qa_gc_fsync_error_all_waiters", run: testQAGCFsyncErrorAllWaiters},
		{name: "qa_gc_intermittent_fsync_error", run: testQAGCIntermittentFsyncError},
		{name: "qa_gc_max_batch_exact", run: testQAGCMaxBatchExact},
		{name: "qa_gc_zero_delay_still_works", run: testQAGCZeroDelayStillWorks},
		{name: "qa_gc_sync_count_accuracy", run: testQAGCSyncCountAccuracy},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}

// testQAGCDoubleStop: Stop() twice must not panic or deadlock.
func testQAGCDoubleStop(t *testing.T) {
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error { return nil },
	})
	go gc.Run()

	gc.Stop()

	// Second stop — must not panic or deadlock. It runs on its own
	// goroutine so that a hang is reported by the timeout arm of the
	// select below instead of wedging the whole test binary.
	done := make(chan struct{})
	go func() {
		gc.Stop()
		close(done)
	}()

	select {
	case <-done:
		// Good.
+ case <-time.After(2 * time.Second): + t.Fatal("second Stop() deadlocked") + } +} + +// testQAGCSubmitStormDuringStop: Many goroutines Submit() while Stop() is called. +// All must either succeed or get ErrGroupCommitShutdown — no panics, no deadlocks. +// +// BUG QA-002: There is a race between drainPending() and close(gc.done) in Run(). +// Goroutines that pass the gc.done double-check AFTER drainPending() releases gc.mu +// but BEFORE close(gc.done) can enqueue to pending with no goroutine to drain them, +// causing a permanent hang on <-ch in Submit(). +// +// Race sequence: +// 1. Run(): drainPending() → gc.mu.Lock → take pending → gc.mu.Unlock → send errors +// 2. Submit(): passes first select<-gc.done (not closed yet) +// 3. Submit(): gc.mu.Lock → passes second select<-gc.done → append ch → gc.mu.Unlock +// 4. Run(): close(gc.done) ← too late, ch already enqueued with no consumer +// 5. Submit(): <-ch blocks forever +func testQAGCSubmitStormDuringStop(t *testing.T) { + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { return nil }, + MaxDelay: 1 * time.Millisecond, + }) + go gc.Run() + + const goroutines = 32 + var wg sync.WaitGroup + errs := make(chan error, goroutines*10) + + // Launch submitters. + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + err := gc.Submit() + if err != nil && !errors.Is(err, ErrGroupCommitShutdown) { + errs <- fmt.Errorf("unexpected error: %w", err) + return + } + if errors.Is(err, ErrGroupCommitShutdown) { + return // stopped, don't retry + } + } + }() + } + + // Race: stop while submitters are in flight. + time.Sleep(1 * time.Millisecond) + gc.Stop() + + // Use timeout to detect QA-002 hang. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // All goroutines exited cleanly. + case <-time.After(5 * time.Second): + // QA-002: Submit() goroutines stuck waiting for response after Stop(). 
+ t.Fatal("BUG QA-002: Submit() goroutines deadlocked during Stop() — " + + "drainPending/close(done) race allows enqueue after drain") + } + + close(errs) + for err := range errs { + t.Errorf("submit storm: %v", err) + } +} + +// testQAGCFsyncErrorAllWaiters: When fsync fails, ALL waiters in the batch +// must receive the error (not just the first one). +func testQAGCFsyncErrorAllWaiters(t *testing.T) { + errDisk := fmt.Errorf("disk on fire") + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { return errDisk }, + MaxDelay: 50 * time.Millisecond, + MaxBatch: 100, + OnDegraded: func() {}, + }) + go gc.Run() + defer gc.Stop() + + const n = 20 + var wg sync.WaitGroup + results := make([]error, n) + + wg.Add(n) + for i := 0; i < n; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = gc.Submit() + }(i) + } + wg.Wait() + + for i, err := range results { + if err == nil { + t.Errorf("waiter %d got nil, want error", i) + } + } +} + +// testQAGCIntermittentFsyncError: fsync alternates success/failure. +// Verify each batch's waiters get the correct result. +func testQAGCIntermittentFsyncError(t *testing.T) { + var callCount atomic.Uint64 + errFlaky := fmt.Errorf("flaky disk") + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + n := callCount.Add(1) + if n%2 == 0 { + return errFlaky // even calls fail + } + return nil // odd calls succeed + }, + MaxDelay: 2 * time.Millisecond, + }) + go gc.Run() + defer gc.Stop() + + // Submit 20 sequential requests (each likely in its own batch). + var successes, failures int + for i := 0; i < 20; i++ { + err := gc.Submit() + if err == nil { + successes++ + } else { + failures++ + } + } + + // Both successes and failures should occur. 
	// Both outcomes must occur: the SyncFunc above alternates success and
	// failure per call, so 20 sequential submits cannot legitimately
	// produce zero of either.
	if successes == 0 {
		t.Error("expected some successful syncs")
	}
	if failures == 0 {
		t.Error("expected some failed syncs from intermittent error")
	}
	t.Logf("intermittent: %d success, %d failure out of 20", successes, failures)
}

// testQAGCMaxBatchExact: Submit exactly maxBatch waiters.
// They should all complete quickly (trigger immediate flush, not wait for maxDelay).
func testQAGCMaxBatchExact(t *testing.T) {
	const maxBatch = 8
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error { return nil },
		MaxDelay: 10 * time.Second, // very long — should NOT wait
		MaxBatch: maxBatch,
	})
	go gc.Run()
	defer gc.Stop()

	var wg sync.WaitGroup
	wg.Add(maxBatch)
	for i := 0; i < maxBatch; i++ {
		go func() {
			defer wg.Done()
			// Error deliberately ignored: this test only cares that all
			// maxBatch waiters complete promptly.
			gc.Submit()
		}()
	}

	// Signal completion through a channel so the select below can catch a
	// committer that waits out the (huge) MaxDelay instead of flushing as
	// soon as the batch is full.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// Good — completed quickly.
	case <-time.After(3 * time.Second):
		t.Fatal("maxBatch exact count did not trigger immediate flush")
	}
}

// testQAGCZeroDelayStillWorks: MaxDelay=0 should not panic or hang.
func testQAGCZeroDelayStillWorks(t *testing.T) {
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error { return nil },
		MaxDelay: 0, // should get default 1ms
	})
	go gc.Run()
	defer gc.Stop()

	// Submit from a goroutine so a hang surfaces as a timeout failure
	// rather than blocking the test forever.
	done := make(chan error, 1)
	go func() {
		done <- gc.Submit()
	}()

	select {
	case err := <-done:
		if err != nil {
			t.Errorf("Submit with zero delay: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Submit with zero delay hung")
	}
}

// testQAGCSyncCountAccuracy: Verify SyncCount matches actual fsync calls.
func testQAGCSyncCountAccuracy(t *testing.T) {
	// Count real SyncFunc invocations independently so the committer's own
	// SyncCount() bookkeeping can be cross-checked against ground truth.
	var actualSyncs atomic.Uint64
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			actualSyncs.Add(1)
			return nil
		},
		MaxDelay: 1 * time.Millisecond,
	})
	go gc.Run()
	defer gc.Stop()

	// 10 sequential submits (each should be its own batch).
	for i := 0; i < 10; i++ {
		if err := gc.Submit(); err != nil {
			t.Fatalf("Submit %d: %v", i, err)
		}
	}

	if gc.SyncCount() != actualSyncs.Load() {
		t.Errorf("SyncCount=%d, actual=%d", gc.SyncCount(), actualSyncs.Load())
	}
}

// --- Task 1.8: Flusher adversarial tests ---

// TestQAFlusher is the table driver for the flusher adversarial subtests.
func TestQAFlusher(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "qa_flush_empty_dirty_map", run: testQAFlushEmptyDirtyMap},
		{name: "qa_flush_overwrite_during_flush", run: testQAFlushOverwriteDuringFlush},
		{name: "qa_flush_trim_zeros_extent", run: testQAFlushTrimZerosExtent},
		{name: "qa_flush_preserves_newer_writes", run: testQAFlushPreservesNewerWrites},
		{name: "qa_flush_checkpoint_persists", run: testQAFlushCheckpointPersists},
		{name: "qa_flush_wal_reclaim_then_write", run: testQAFlushWALReclaimThenWrite},
		{name: "qa_flush_multi_block_entry", run: testQAFlushMultiBlockEntry},
		{name: "qa_flusher_stop_idempotent", run: testQAFlusherStopIdempotent},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}

// testQAFlushEmptyDirtyMap: FlushOnce with no dirty entries is a no-op.
func testQAFlushEmptyDirtyMap(t *testing.T) {
	// NOTE(review): createTestVolWithFlusher is defined elsewhere in the
	// package — presumably it returns a small volume plus an unstarted
	// flusher; confirm against the helper.
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// No writes — flush should not error or change anything.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce on empty: %v", err)
	}
	if f.CheckpointLSN() != 0 {
		t.Errorf("checkpoint should be 0 on empty flush, got %d", f.CheckpointLSN())
	}
}

// testQAFlushOverwriteDuringFlush: Write, flush, overwrite same LBA, flush again.
+// Verify final state is the overwritten data. +func testQAFlushOverwriteDuringFlush(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write initial data. + if err := v.WriteLBA(0, makeBlock('A')); err != nil { + t.Fatalf("WriteLBA(A): %v", err) + } + + // Flush — moves 'A' to extent. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 1: %v", err) + } + + // Overwrite with 'B'. + if err := v.WriteLBA(0, makeBlock('B')); err != nil { + t.Fatalf("WriteLBA(B): %v", err) + } + + // Read should return 'B' (from dirty map, not extent). + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA before second flush: %v", err) + } + if !bytes.Equal(got, makeBlock('B')) { + t.Error("before second flush: should read 'B' from WAL") + } + + // Flush again — moves 'B' to extent. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 2: %v", err) + } + + // Read should still return 'B' (now from extent). + got, err = v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after second flush: %v", err) + } + if !bytes.Equal(got, makeBlock('B')) { + t.Error("after second flush: should read 'B' from extent") + } +} + +// testQAFlushTrimZerosExtent: Write, flush (data in extent), trim, flush again. +// After second flush, extent should contain zeros. +func testQAFlushTrimZerosExtent(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write data. + if err := v.WriteLBA(0, makeBlock('X')); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Flush — 'X' goes to extent. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 1: %v", err) + } + + // Verify extent has 'X'. + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after first flush: %v", err) + } + if !bytes.Equal(got, makeBlock('X')) { + t.Fatal("extent should have 'X' after first flush") + } + + // Trim the block. 
+ if err := v.Trim(0, 4096); err != nil { + t.Fatalf("Trim: %v", err) + } + + // Read from dirty map (TRIM entry) — should return zeros. + got, err = v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after trim: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("after trim: dirty map read should return zeros") + } + + // Flush again — flusher zeros the extent. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 2: %v", err) + } + + // Dirty map should be empty. + if v.dirtyMap.Len() != 0 { + t.Errorf("dirty map should be empty after flush, got %d", v.dirtyMap.Len()) + } + + // Read from extent — should be zeros (flusher zeroed it). + got, err = v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after trim flush: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("after trim+flush: extent should be zeros") + } +} + +// testQAFlushPreservesNewerWrites: Write A, start flush snapshot, write B to same LBA +// before flush removes from dirty map. Dirty map should keep B (newer LSN). +func testQAFlushPreservesNewerWrites(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write 'A' to LBA 0. + if err := v.WriteLBA(0, makeBlock('A')); err != nil { + t.Fatalf("WriteLBA(A): %v", err) + } + + // Also write 'M' to LBA 5 to ensure flush has something. + if err := v.WriteLBA(5, makeBlock('M')); err != nil { + t.Fatalf("WriteLBA(5): %v", err) + } + + // Flush moves both to extent. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // Overwrite LBA 0 with 'B' AFTER flush. + if err := v.WriteLBA(0, makeBlock('B')); err != nil { + t.Fatalf("WriteLBA(B): %v", err) + } + + // Now flush again — flusher should see the new entry for LBA 0. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 2: %v", err) + } + + // LBA 0 should be 'B' from extent. 
+ got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA: %v", err) + } + if !bytes.Equal(got, makeBlock('B')) { + t.Error("LBA 0 should be 'B' after overwrite+flush") + } +} + +// testQAFlushCheckpointPersists: Flush, crash, reopen. Checkpoint LSN should +// be persisted so recovery skips already-flushed entries. +func testQAFlushCheckpointPersists(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "checkpoint.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write and flush blocks 0-4. + for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + f := NewFlusher(FlusherConfig{ + FD: v.fd, + Super: &v.super, + WAL: v.wal, + DirtyMap: v.dirtyMap, + Interval: 1 * time.Hour, + }) + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + checkpointLSN := f.CheckpointLSN() + if checkpointLSN == 0 { + t.Fatal("checkpoint should be non-zero after flush") + } + + // Write more blocks AFTER checkpoint. + for i := uint64(5); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Update superblock and crash. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + // Reopen — recovery should skip LSN <= checkpoint and replay 5-9. + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Blocks 0-4 from extent (flushed), blocks 5-9 from WAL (replayed). 
+ for i := uint64(0); i < 10; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: data mismatch after checkpoint recovery", i) + } + } +} + +// testQAFlushWALReclaimThenWrite: Fill WAL, flush (reclaim all), write again. +// Tests that WAL space is truly freed and reusable. +func testQAFlushWALReclaimThenWrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "reclaim.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 128 * 1024, // small WAL + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + defer v.Close() + + entrySize := uint64(walEntryHeaderSize + 4096) + maxEntries := int(128 * 1024 / entrySize) + + // Fill WAL completely. + for i := 0; i < maxEntries; i++ { + if err := v.WriteLBA(uint64(i), makeBlock(byte(i%26+'A'))); err != nil { + break // expected ErrWALFull + } + } + + // Confirm WAL is full. + err = v.WriteLBA(0, makeBlock('Z')) + if err == nil { + // Might succeed if we didn't quite fill it. Try more. + for i := 0; i < 100; i++ { + if err := v.WriteLBA(uint64(i), makeBlock('Z')); err != nil { + break + } + } + } + + // Flush — reclaim all WAL space. + f := NewFlusher(FlusherConfig{ + FD: v.fd, + Super: &v.super, + WAL: v.wal, + DirtyMap: v.dirtyMap, + Interval: 1 * time.Hour, + }) + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // WAL should be empty now. Write again — should succeed. + for i := 0; i < 5; i++ { + if err := v.WriteLBA(uint64(i), makeBlock(byte('a'+i))); err != nil { + t.Fatalf("write after reclaim %d: %v", i, err) + } + } + + // Verify. 
+ for i := 0; i < 5; i++ { + got, err := v.ReadLBA(uint64(i), 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('a'+i))) { + t.Errorf("block %d: data mismatch after reclaim+rewrite", i) + } + } +} + +// testQAFlushMultiBlockEntry: Write multi-block entry, flush, verify all blocks +// are correctly placed in the extent. +func testQAFlushMultiBlockEntry(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write 3 blocks as one WriteLBA call. + data := make([]byte, 3*4096) + for i := 0; i < 4096; i++ { + data[i] = 'P' + data[4096+i] = 'Q' + data[2*4096+i] = 'R' + } + if err := v.WriteLBA(10, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Flush. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // All 3 blocks should be readable from extent. + for i, expected := range []byte{'P', 'Q', 'R'} { + got, err := v.ReadLBA(uint64(10+i), 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", 10+i, err) + } + if !bytes.Equal(got, makeBlock(expected)) { + t.Errorf("block %d: expected '%c', got different data", 10+i, expected) + } + } +} + +// testQAFlusherStopIdempotent: Stop() twice on the flusher goroutine. +func testQAFlusherStopIdempotent(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + go f.Run() + f.Stop() + + // Second stop — must not panic or deadlock. + done := make(chan struct{}) + go func() { + f.Stop() + close(done) + }() + + select { + case <-done: + // Good. 
	case <-time.After(2 * time.Second):
		t.Fatal("second Flusher.Stop() deadlocked")
	}
}

// --- Task 1.9: Recovery adversarial tests ---

// TestQARecovery is the table driver for the crash-recovery adversarial
// subtests; each entry runs as its own named subtest.
func TestQARecovery(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "qa_recover_trim_entry", run: testQARecoverTrimEntry},
		{name: "qa_recover_mixed_write_trim_barrier", run: testQARecoverMixedWriteTrimBarrier},
		{name: "qa_recover_after_flush_then_crash", run: testQARecoverAfterFlushThenCrash},
		{name: "qa_recover_overwrite_same_lba", run: testQARecoverOverwriteSameLBA},
		{name: "qa_recover_crash_loop", run: testQARecoverCrashLoop},
		{name: "qa_recover_corrupt_middle_entry", run: testQARecoverCorruptMiddleEntry},
		{name: "qa_recover_multi_block_write", run: testQARecoverMultiBlockWrite},
		{name: "qa_recover_oracle_with_crash", run: testQARecoverOracleWithCrash},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}

// testQARecoverTrimEntry: Write, trim, sync, crash, recover.
// After recovery, trimmed LBA should return zeros.
func testQARecoverTrimEntry(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trim_recover.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	// Write then trim the same LBA: replay must apply the later TRIM over
	// the earlier WRITE, so recovery order matters here.
	if err := v.WriteLBA(3, makeBlock('T')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	if err := v.Trim(3, 4096); err != nil {
		t.Fatalf("Trim: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist superblock.
+ v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + got, err := v2.ReadLBA(3, 4096) + if err != nil { + t.Fatalf("ReadLBA(3): %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("trimmed LBA should be zeros after recovery") + } +} + +// testQARecoverMixedWriteTrimBarrier: Interleave WRITE, TRIM, and BARRIER entries. +// Verify recovery replays correctly. +func testQARecoverMixedWriteTrimBarrier(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mixed_recover.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write LBA 0, 1, 2. + for i := uint64(0); i < 3; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Trim LBA 1. + if err := v.Trim(1, 4096); err != nil { + t.Fatalf("Trim(1): %v", err) + } + + // Write a barrier. + lsn := v.nextLSN.Add(1) - 1 + barrier := &WALEntry{LSN: lsn, Type: EntryTypeBarrier, LBA: 0} + if _, err := v.wal.Append(barrier); err != nil { + t.Fatalf("Append barrier: %v", err) + } + + // Write LBA 5. 
+ if err := v.WriteLBA(5, makeBlock('Z')); err != nil { + t.Fatalf("WriteLBA(5): %v", err) + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // LBA 0: 'A' + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0): %v", err) + } + if !bytes.Equal(got, makeBlock('A')) { + t.Error("LBA 0 should be 'A'") + } + + // LBA 1: trimmed — zeros. + got, err = v2.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1): %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("LBA 1 should be zeros (trimmed)") + } + + // LBA 2: 'C' + got, err = v2.ReadLBA(2, 4096) + if err != nil { + t.Fatalf("ReadLBA(2): %v", err) + } + if !bytes.Equal(got, makeBlock('C')) { + t.Error("LBA 2 should be 'C'") + } + + // LBA 5: 'Z' + got, err = v2.ReadLBA(5, 4096) + if err != nil { + t.Fatalf("ReadLBA(5): %v", err) + } + if !bytes.Equal(got, makeBlock('Z')) { + t.Error("LBA 5 should be 'Z'") + } +} + +// testQARecoverAfterFlushThenCrash: Flush some entries, write more, crash. +// Flushed entries should be in extent; post-flush writes recovered from WAL. +func testQARecoverAfterFlushThenCrash(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "flush_crash.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write blocks 0-4 and flush. 
+ for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + f := NewFlusher(FlusherConfig{ + FD: v.fd, Super: &v.super, WAL: v.wal, DirtyMap: v.dirtyMap, + Interval: 1 * time.Hour, + }) + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // Write blocks 5-9 AFTER flush. + for i := uint64(5); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Persist superblock. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := uint64(0); i < 10; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: data mismatch (flush+crash recovery)", i) + } + } +} + +// testQARecoverOverwriteSameLBA: Write LBA 0 three times, sync, crash, recover. +// Recovery should replay the latest write for LBA 0. +func testQARecoverOverwriteSameLBA(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "overwrite_recover.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write LBA 0 three times with different data. 
+ if err := v.WriteLBA(0, makeBlock('X')); err != nil { + t.Fatalf("WriteLBA(X): %v", err) + } + if err := v.WriteLBA(0, makeBlock('Y')); err != nil { + t.Fatalf("WriteLBA(Y): %v", err) + } + if err := v.WriteLBA(0, makeBlock('Z')); err != nil { + t.Fatalf("WriteLBA(Z): %v", err) + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA: %v", err) + } + if !bytes.Equal(got, makeBlock('Z')) { + t.Error("LBA 0 should be 'Z' (latest write) after recovery") + } +} + +// testQARecoverCrashLoop: Write, sync, crash, recover — 20 iterations. +// Each iteration writes new data and verifies previous data survived. +func testQARecoverCrashLoop(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "crashloop.blockvol") + + // Create initial volume. + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + const iterations = 20 + + for iter := 0; iter < iterations; iter++ { + lba := uint64(iter % 200) // spread across LBAs + data := makeBlock(byte(iter % 256)) + + if err := v.WriteLBA(lba, data); err != nil { + t.Fatalf("iter %d WriteLBA: %v", iter, err) + } + if err := v.SyncCache(); err != nil { + t.Fatalf("iter %d SyncCache: %v", iter, err) + } + + // Persist superblock state. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + // Crash and recover. 
+ path = simulateCrash(v) + v, err = OpenBlockVol(path) + if err != nil { + t.Fatalf("iter %d OpenBlockVol: %v", iter, err) + } + + // Verify the data we just wrote. + got, err := v.ReadLBA(lba, 4096) + if err != nil { + t.Fatalf("iter %d ReadLBA: %v", iter, err) + } + if !bytes.Equal(got, data) { + t.Fatalf("iter %d: data mismatch for LBA %d", iter, lba) + } + } + + v.Close() +} + +// testQARecoverCorruptMiddleEntry: Write 3 entries, corrupt the 2nd entry's CRC. +// Recovery should replay entry 1, skip entries 2+3 (torn write boundary). +func testQARecoverCorruptMiddleEntry(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "corrupt_mid.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write 3 entries. + for i := uint64(0); i < 3; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + // Corrupt 2nd entry CRC (byte in data area of 2nd entry). + entrySize := uint64(walEntryHeaderSize + 4096) + // Corrupt a byte inside the 2nd entry's data region. + corruptOff := int64(v.super.WALOffset + entrySize + uint64(walEntryHeaderSize) + 10) + v.fd.WriteAt([]byte{0xFF}, corruptOff) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Entry 1 (LBA 0) should be recovered. 
+ got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0): %v", err) + } + if !bytes.Equal(got, makeBlock('A')) { + t.Error("LBA 0 should be 'A' (recovered before corrupt entry)") + } + + // Entries 2+3 (LBA 1,2) should NOT be recovered (CRC failure stops scan). + got, err = v2.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1): %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("LBA 1 should be zeros (corrupt entry discarded)") + } + + got, err = v2.ReadLBA(2, 4096) + if err != nil { + t.Fatalf("ReadLBA(2): %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("LBA 2 should be zeros (entry after corrupt entry discarded)") + } +} + +// testQARecoverMultiBlockWrite: Write multi-block entry, crash, recover. +// Verify all blocks from the multi-block entry are recovered. +func testQARecoverMultiBlockWrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "multiblock_recover.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write 4 blocks as one call. 
+ data := make([]byte, 4*4096) + for i := 0; i < 4; i++ { + for j := 0; j < 4096; j++ { + data[i*4096+j] = byte('W' + i) + } + } + if err := v.WriteLBA(10, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := 0; i < 4; i++ { + got, err := v2.ReadLBA(uint64(10+i), 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", 10+i, err) + } + expected := makeBlock(byte('W' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: expected '%c', got different data", 10+i, byte('W'+i)) + } + } +} + +// testQARecoverOracleWithCrash: Oracle pattern with periodic crash+recover. +// This is the most valuable adversarial test — it exercises the full +// write→sync→crash→recover→verify cycle with random operations. +func testQARecoverOracleWithCrash(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "oracle_crash.blockvol") + + const blockSize = 4096 + const numBlocks = 32 + const volSize = numBlocks * blockSize + + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: volSize, + BlockSize: blockSize, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + oracle := make(map[uint64][]byte) + rng := rand.New(rand.NewSource(0xCAFEBABE)) + + const iterations = 10 + const opsPerIter = 30 + + for iter := 0; iter < iterations; iter++ { + // Execute random ops. 
+ for op := 0; op < opsPerIter; op++ { + lba := uint64(rng.Intn(numBlocks)) + action := rng.Intn(3) + + switch action { + case 0: // WRITE + data := make([]byte, blockSize) + rng.Read(data) + err := v.WriteLBA(lba, data) + if err != nil { + if errors.Is(err, ErrWALFull) { + continue + } + t.Fatalf("iter %d op %d: WriteLBA(%d): %v", iter, op, lba, err) + } + oracle[lba] = data + + case 1: // READ (verify against oracle) + got, err := v.ReadLBA(lba, blockSize) + if err != nil { + t.Fatalf("iter %d op %d: ReadLBA(%d): %v", iter, op, lba, err) + } + expected, ok := oracle[lba] + if !ok { + expected = make([]byte, blockSize) + } + if !bytes.Equal(got, expected) { + t.Fatalf("iter %d op %d: LBA %d oracle mismatch", iter, op, lba) + } + + case 2: // TRIM + err := v.Trim(lba, blockSize) + if err != nil { + if errors.Is(err, ErrWALFull) { + continue + } + t.Fatalf("iter %d op %d: Trim(%d): %v", iter, op, lba, err) + } + delete(oracle, lba) + } + } + + // Sync and crash. + if err := v.SyncCache(); err != nil { + t.Fatalf("iter %d SyncCache: %v", iter, err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + // Recover. + v, err = OpenBlockVol(path) + if err != nil { + t.Fatalf("iter %d OpenBlockVol: %v", iter, err) + } + + // Verify all oracle entries after recovery. 
+ for lba := uint64(0); lba < numBlocks; lba++ { + got, err := v.ReadLBA(lba, blockSize) + if err != nil { + t.Fatalf("iter %d verify LBA %d: %v", iter, lba, err) + } + expected, ok := oracle[lba] + if !ok { + expected = make([]byte, blockSize) + } + if !bytes.Equal(got, expected) { + t.Fatalf("iter %d post-recovery: LBA %d oracle mismatch", iter, lba) + } + } + } + + v.Close() + t.Logf("oracle crash test: %d iterations x %d ops, all consistent", iterations, opsPerIter) +} + +// ============================================================================ +// QA Adversarial Tests — Tasks 1.10 (Lifecycle), 1.11 (Crash Stress) +// ============================================================================ + +// --- Task 1.10: Lifecycle adversarial tests --- + +func TestQALifecycle(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "qa_lifecycle_write_after_close", run: testQALifecycleWriteAfterClose}, + {name: "qa_lifecycle_read_after_close", run: testQALifecycleReadAfterClose}, + {name: "qa_lifecycle_sync_after_close", run: testQALifecycleSyncAfterClose}, + {name: "qa_lifecycle_close_drains_dirty", run: testQALifecycleCloseDrainsDirty}, + {name: "qa_lifecycle_multi_cycle_accumulate", run: testQALifecycleMultiCycleAccumulate}, + {name: "qa_lifecycle_close_with_background_flusher", run: testQALifecycleCloseWithBackgroundFlusher}, + {name: "qa_lifecycle_healthy_flag", run: testQALifecycleHealthyFlag}, + {name: "qa_lifecycle_open_close_rapid", run: testQALifecycleOpenCloseRapid}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +// testQALifecycleWriteAfterClose: WriteLBA after Close must fail gracefully (not panic). +func testQALifecycleWriteAfterClose(t *testing.T) { + v := createTestVol(t) + v.Close() + + // Write after close — fd is closed, should get an error, never a panic. 
+ err := v.WriteLBA(0, makeBlock('X')) + if err == nil { + t.Error("WriteLBA after Close should fail") + } +} + +// testQALifecycleReadAfterClose: ReadLBA after Close must fail gracefully (not panic). +func testQALifecycleReadAfterClose(t *testing.T) { + v := createTestVol(t) + + // Write something first so dirty map has an entry. + if err := v.WriteLBA(0, makeBlock('R')); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + v.Close() + + // Read after close — fd is closed, should error, not panic. + _, err := v.ReadLBA(0, 4096) + if err == nil { + t.Error("ReadLBA after Close should fail") + } +} + +// testQALifecycleSyncAfterClose: SyncCache after Close must return shutdown error. +func testQALifecycleSyncAfterClose(t *testing.T) { + v := createTestVol(t) + v.Close() + + err := v.SyncCache() + if err == nil { + t.Error("SyncCache after Close should fail") + } + if !errors.Is(err, ErrGroupCommitShutdown) { + t.Errorf("SyncCache after Close: got %v, want ErrGroupCommitShutdown", err) + } +} + +// testQALifecycleCloseDrainsDirty: Close does a final flush — dirty map should +// be empty and data should be in extent region. Reopen should find blocks in +// extent (not WAL) and dirty map should be empty. +func testQALifecycleCloseDrainsDirty(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "drain.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write 20 blocks. + for i := uint64(0); i < 20; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i%26))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Close — should flush all dirty blocks to extent. + if err := v.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Reopen — recovery should find nothing in WAL (all flushed). 
+ v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Dirty map should be empty (all data is in extent). + if v2.dirtyMap.Len() != 0 { + t.Errorf("dirty map after reopen should be 0, got %d (close didn't fully flush)", v2.dirtyMap.Len()) + } + + // Verify all blocks readable. + for i := uint64(0); i < 20; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i%26)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: mismatch after close+reopen", i) + } + } +} + +// testQALifecycleMultiCycleAccumulate: Write→sync→close→reopen→write more, 5 cycles. +// Each cycle adds new blocks. Verify all accumulated blocks survive. +func testQALifecycleMultiCycleAccumulate(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "accumulate.blockvol") + + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + oracle := make(map[uint64]byte) + + for cycle := 0; cycle < 5; cycle++ { + // Write 10 blocks per cycle at different LBAs. + for i := 0; i < 10; i++ { + lba := uint64(cycle*10 + i) + fill := byte('A' + (cycle*10+i)%26) + if err := v.WriteLBA(lba, makeBlock(fill)); err != nil { + t.Fatalf("cycle %d WriteLBA(%d): %v", cycle, lba, err) + } + oracle[lba] = fill + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("cycle %d SyncCache: %v", cycle, err) + } + if err := v.Close(); err != nil { + t.Fatalf("cycle %d Close: %v", cycle, err) + } + + v, err = OpenBlockVol(path) + if err != nil { + t.Fatalf("cycle %d OpenBlockVol: %v", cycle, err) + } + + // Verify all accumulated data. 
+ for lba, fill := range oracle { + got, err := v.ReadLBA(lba, 4096) + if err != nil { + t.Fatalf("cycle %d ReadLBA(%d): %v", cycle, lba, err) + } + if !bytes.Equal(got, makeBlock(fill)) { + t.Fatalf("cycle %d block %d: data mismatch", cycle, lba) + } + } + } + + v.Close() + t.Logf("multi-cycle: 5 cycles, %d blocks accumulated, all consistent", len(oracle)) +} + +// testQALifecycleCloseWithBackgroundFlusher: Write enough to trigger background +// flusher (100ms interval), then close. Verify shutdown ordering is correct. +func testQALifecycleCloseWithBackgroundFlusher(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bgflush.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write blocks and let the background flusher potentially kick in. + for i := uint64(0); i < 30; i++ { + if err := v.WriteLBA(i, makeBlock(byte(i%26+'A'))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Wait a bit to let the flusher potentially run. + time.Sleep(150 * time.Millisecond) + + // Close — must coordinate with flusher goroutine. + if err := v.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Reopen and verify. + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := uint64(0); i < 30; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte(i%26+'A'))) { + t.Errorf("block %d: mismatch after flusher+close+reopen", i) + } + } +} + +// testQALifecycleHealthyFlag: Verify Info().Healthy reflects fsync failures. 
+func testQALifecycleHealthyFlag(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + if !v.Info().Healthy { + t.Error("volume should be healthy initially") + } + + // Force unhealthy by directly setting the flag (simulating fsync error + // that the OnDegraded callback would trigger). + v.healthy.Store(false) + if v.Info().Healthy { + t.Error("volume should report unhealthy after flag set") + } + + // Restore. + v.healthy.Store(true) + if !v.Info().Healthy { + t.Error("volume should report healthy after restoration") + } +} + +// testQALifecycleOpenCloseRapid: Open and close 20 times rapidly. +// Tests for goroutine/resource leaks. +func testQALifecycleOpenCloseRapid(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "rapid.blockvol") + + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 128 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write one block so there's something to flush. + if err := v.WriteLBA(0, makeBlock('R')); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + v.Close() + + for i := 0; i < 20; i++ { + v, err = OpenBlockVol(path) + if err != nil { + t.Fatalf("open %d: %v", i, err) + } + if err := v.Close(); err != nil { + t.Fatalf("close %d: %v", i, err) + } + } + + // Final open — verify data survived 20 open/close cycles. 
+ v, err = OpenBlockVol(path) + if err != nil { + t.Fatalf("final open: %v", err) + } + defer v.Close() + + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("final read: %v", err) + } + if !bytes.Equal(got, makeBlock('R')) { + t.Error("data lost after 20 rapid open/close cycles") + } +} + +// --- Task 1.11: Crash stress adversarial tests --- + +func TestQACrashStress(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "qa_crash_no_sync_data_loss_ok", run: testQACrashNoSyncDataLossOK}, + {name: "qa_crash_with_flush_then_crash", run: testQACrashWithFlushThenCrash}, + {name: "qa_crash_wal_near_full", run: testQACrashWALNearFull}, + {name: "qa_crash_concurrent_writers", run: testQACrashConcurrentWriters}, + {name: "qa_crash_trim_heavy", run: testQACrashTrimHeavy}, + {name: "qa_crash_multi_block_stress", run: testQACrashMultiBlockStress}, + {name: "qa_crash_overwrite_storm", run: testQACrashOverwriteStorm}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +// testQACrashNoSyncDataLossOK: Write WITHOUT SyncCache, crash, recover. +// Un-synced data MAY be lost — this is correct behavior, not a bug. +// The key invariant: volume must open without error and previously synced +// data must survive. +func testQACrashNoSyncDataLossOK(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "nosync.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write block 0 and SYNC it. + if err := v.WriteLBA(0, makeBlock('S')); err != nil { + t.Fatalf("WriteLBA(synced): %v", err) + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Persist superblock. 
+ v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + // Write block 1 WITHOUT sync. + if err := v.WriteLBA(1, makeBlock('U')); err != nil { + t.Fatalf("WriteLBA(unsynced): %v", err) + } + + // Hard crash (no sync, no superblock update for block 1). + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Synced block 0 must survive. + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0): %v", err) + } + if !bytes.Equal(got, makeBlock('S')) { + t.Error("synced block 0 should survive crash") + } + + // Un-synced block 1: may or may not be there — both are correct. + // Just verify we can read without error. + _, err = v2.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1) should not error: %v", err) + } +} + +// testQACrashWithFlushThenCrash: Write, let flusher run (data in extent), +// write more, crash WITHOUT sync. Flushed data must survive. +func testQACrashWithFlushThenCrash(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "flush_crash.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: 256 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write blocks 0-4 and sync. + for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Wait for background flusher to pick up the entries (100ms interval). + time.Sleep(200 * time.Millisecond) + + // Write more blocks (in WAL, not yet synced to superblock). 
+ for i := uint64(5); i < 8; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Persist superblock with latest WAL state. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // All blocks 0-7 should be readable (0-4 from extent, 5-7 from WAL). + for i := uint64(0); i < 8; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: mismatch after flush+crash", i) + } + } +} + +// testQACrashWALNearFull: Fill WAL to near-capacity, sync, crash, recover. +// All synced entries must be recoverable even when WAL is almost full. +func testQACrashWALNearFull(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "walfull.blockvol") + walSize := uint64(64 * 1024) // tiny 64KB WAL + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: walSize, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + entrySize := uint64(walEntryHeaderSize + 4096) + maxEntries := int(walSize / entrySize) + + // Write up to capacity. 
+ var written int + for i := 0; i < maxEntries; i++ { + err := v.WriteLBA(uint64(i), makeBlock(byte('A'+i%26))) + if err != nil { + break // ErrWALFull + } + written++ + } + + if written == 0 { + t.Fatal("couldn't write any entries") + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := 0; i < written; i++ { + got, err := v2.ReadLBA(uint64(i), 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('A'+i%26))) { + t.Errorf("block %d: mismatch after near-full WAL recovery", i) + } + } + + t.Logf("near-full WAL: %d/%d entries recovered", written, maxEntries) +} + +// testQACrashConcurrentWriters: Multiple goroutines write, then sync, then crash. +// All synced data must survive recovery. +func testQACrashConcurrentWriters(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "concurrent_crash.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 4 * 1024 * 1024, + BlockSize: 4096, + WALSize: 2 * 1024 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + const goroutines = 8 + const opsPerGoroutine = 20 + + // Each goroutine writes to its own LBA range. 
+ var wg sync.WaitGroup + writtenLBAs := make([][]uint64, goroutines) + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + baseLBA := uint64(id * opsPerGoroutine) + for i := 0; i < opsPerGoroutine; i++ { + lba := baseLBA + uint64(i) + data := makeBlock(byte('A' + id%26)) + if err := v.WriteLBA(lba, data); err != nil { + if errors.Is(err, ErrWALFull) { + return + } + return + } + writtenLBAs[id] = append(writtenLBAs[id], lba) + } + }(g) + } + wg.Wait() + + // Sync everything. + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Verify all written LBAs. + var totalVerified int + for g := 0; g < goroutines; g++ { + expected := makeBlock(byte('A' + g%26)) + for _, lba := range writtenLBAs[g] { + got, err := v2.ReadLBA(lba, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", lba, err) + } + if !bytes.Equal(got, expected) { + t.Errorf("goroutine %d LBA %d: mismatch after concurrent crash", g, lba) + } + totalVerified++ + } + } + t.Logf("concurrent crash: %d blocks verified from %d goroutines", totalVerified, goroutines) +} + +// testQACrashTrimHeavy: Crash loop with heavy trim operations. +// Verifies trim semantics survive crash+recovery correctly. +func testQACrashTrimHeavy(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "trim_crash.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 256 * 1024, // 64 blocks + BlockSize: 4096, + WALSize: 128 * 1024, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + oracle := make(map[uint64]byte) + const maxLBA = 64 + + for iter := 0; iter < 20; iter++ { + // Write some blocks. 
+		for i := 0; i < 4; i++ {
+			lba := uint64((iter*3 + i*7) % maxLBA)
+			fill := byte('A' + (iter+i)%26)
+			if err := v.WriteLBA(lba, makeBlock(fill)); err != nil {
+				if errors.Is(err, ErrWALFull) {
+					// WAL back-pressure: skip this write. The oracle is only
+					// updated below, after a successful write, so it stays
+					// consistent with what the volume acknowledged.
+					continue
+				}
+				t.Fatalf("iter %d WriteLBA(%d): %v", iter, lba, err)
+			}
+			oracle[lba] = fill
+		}
+
+		// Trim half of what we wrote.
+		// NOTE: i ranges over 0..1, the same (iter*3 + i*7) formula as the
+		// write loop above, so each trim targets an LBA written this iteration.
+		for i := 0; i < 2; i++ {
+			lba := uint64((iter*3 + i*7) % maxLBA)
+			if err := v.Trim(lba, 4096); err != nil {
+				if errors.Is(err, ErrWALFull) {
+					continue
+				}
+				t.Fatalf("iter %d Trim(%d): %v", iter, lba, err)
+			}
+			// Sentinel: fill == 0 means "trimmed, expect zeros". This never
+			// collides with real data because write fills are always 'A'..'Z'.
+			oracle[lba] = 0
+		}
+
+		if err := v.SyncCache(); err != nil {
+			t.Fatalf("iter %d SyncCache: %v", iter, err)
+		}
+
+		// Persist WAL head/tail into the superblock so recovery after the
+		// simulated crash sees the synced state.
+		v.super.WALHead = v.wal.LogicalHead()
+		v.super.WALTail = v.wal.LogicalTail()
+		v.fd.Seek(0, 0)
+		v.super.WriteTo(v.fd)
+		v.fd.Sync()
+
+		path = simulateCrash(v)
+
+		v, err = OpenBlockVol(path)
+		if err != nil {
+			t.Fatalf("iter %d OpenBlockVol: %v", iter, err)
+		}
+
+		// Verify oracle.
+		for lba, fill := range oracle {
+			got, err := v.ReadLBA(lba, 4096)
+			if err != nil {
+				t.Fatalf("iter %d ReadLBA(%d): %v", iter, lba, err)
+			}
+			var expected []byte
+			if fill == 0 {
+				expected = make([]byte, 4096)
+			} else {
+				expected = makeBlock(fill)
+			}
+			if !bytes.Equal(got, expected) {
+				t.Fatalf("iter %d LBA %d: oracle mismatch (got[0]=%d want[0]=%d)",
+					iter, lba, got[0], expected[0])
+			}
+		}
+	}
+
+	v.Close()
+	t.Logf("trim-heavy crash: 20 iterations, all consistent")
+}
+
+// testQACrashMultiBlockStress: Crash loop with multi-block writes (2-4 blocks).
+// Exercises recovery of multi-block WAL entries.
+func testQACrashMultiBlockStress(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "multiblock_crash.blockvol")
+	v, err := CreateBlockVol(path, CreateOptions{
+		VolumeSize: 1 * 1024 * 1024,
+		BlockSize:  4096,
+		WALSize:    256 * 1024,
+	})
+	if err != nil {
+		t.Fatalf("CreateBlockVol: %v", err)
+	}
+
+	// oracle maps LBA -> fill byte for every block the volume has
+	// ACKNOWLEDGED. It must only be updated after WriteLBA succeeds,
+	// otherwise an ErrWALFull skip leaves the oracle claiming data the
+	// volume never accepted and verification fails spuriously.
+	oracle := make(map[uint64]byte)
+
+	for iter := 0; iter < 15; iter++ {
+		// Write 2-4 blocks as a single call.
+		nBlocks := 2 + (iter % 3) // 2, 3, or 4
+		baseLBA := uint64((iter * 5) % 200)
+
+		data := make([]byte, nBlocks*4096)
+		fills := make([]byte, nBlocks)
+		for b := 0; b < nBlocks; b++ {
+			fill := byte('A' + (iter+b)%26)
+			for j := 0; j < 4096; j++ {
+				data[b*4096+j] = fill
+			}
+			fills[b] = fill
+		}
+
+		if err := v.WriteLBA(baseLBA, data); err != nil {
+			if errors.Is(err, ErrWALFull) {
+				// Can't write, skip this iteration. The oracle has not been
+				// touched, so it still reflects only acknowledged writes.
+				continue
+			}
+			t.Fatalf("iter %d WriteLBA: %v", iter, err)
+		}
+
+		// Write succeeded — now record it in the oracle.
+		for b := 0; b < nBlocks; b++ {
+			oracle[baseLBA+uint64(b)] = fills[b]
+		}
+
+		if err := v.SyncCache(); err != nil {
+			t.Fatalf("iter %d SyncCache: %v", iter, err)
+		}
+
+		v.super.WALHead = v.wal.LogicalHead()
+		v.super.WALTail = v.wal.LogicalTail()
+		v.fd.Seek(0, 0)
+		v.super.WriteTo(v.fd)
+		v.fd.Sync()
+
+		path = simulateCrash(v)
+
+		v, err = OpenBlockVol(path)
+		if err != nil {
+			t.Fatalf("iter %d OpenBlockVol: %v", iter, err)
+		}
+
+		// Verify all oracle entries.
+		for lba, fill := range oracle {
+			got, err := v.ReadLBA(lba, 4096)
+			if err != nil {
+				t.Fatalf("iter %d ReadLBA(%d): %v", iter, lba, err)
+			}
+			if !bytes.Equal(got, makeBlock(fill)) {
+				t.Fatalf("iter %d LBA %d: mismatch", iter, lba)
+			}
+		}
+	}
+
+	v.Close()
+	t.Logf("multi-block crash: 15 iterations, %d blocks tracked, all consistent", len(oracle))
+}
+
+// testQACrashOverwriteStorm: Overwrite the same LBA many times across crash
+// iterations. Latest synced write must always win.
+func testQACrashOverwriteStorm(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "overwrite_crash.blockvol")
+	v, err := CreateBlockVol(path, CreateOptions{
+		VolumeSize: 1 * 1024 * 1024,
+		BlockSize:  4096,
+		WALSize:    256 * 1024,
+	})
+	if err != nil {
+		t.Fatalf("CreateBlockVol: %v", err)
+	}
+
+	for iter := 0; iter < 30; iter++ {
+		fill := byte(iter % 256)
+		// Overwrite LBA 0 with a new value each iteration.
+		if err := v.WriteLBA(0, makeBlock(fill)); err != nil {
+			if errors.Is(err, ErrWALFull) {
+				// Need to persist what we have and cycle. The sync must be
+				// checked: if it fails, the cycle would silently lose the
+				// very data this test asserts on.
+				if err := v.SyncCache(); err != nil {
+					t.Fatalf("iter %d SyncCache (WAL full): %v", iter, err)
+				}
+				v.super.WALHead = v.wal.LogicalHead()
+				v.super.WALTail = v.wal.LogicalTail()
+				v.fd.Seek(0, 0)
+				v.super.WriteTo(v.fd)
+				v.fd.Sync()
+				path = simulateCrash(v)
+				v, err = OpenBlockVol(path)
+				if err != nil {
+					t.Fatalf("iter %d reopen: %v", iter, err)
+				}
+				// Retry write.
+				if err := v.WriteLBA(0, makeBlock(fill)); err != nil {
+					t.Fatalf("iter %d retry WriteLBA: %v", iter, err)
+				}
+			} else {
+				t.Fatalf("iter %d WriteLBA: %v", iter, err)
+			}
+		}
+
+		if err := v.SyncCache(); err != nil {
+			t.Fatalf("iter %d SyncCache: %v", iter, err)
+		}
+
+		v.super.WALHead = v.wal.LogicalHead()
+		v.super.WALTail = v.wal.LogicalTail()
+		v.fd.Seek(0, 0)
+		v.super.WriteTo(v.fd)
+		v.fd.Sync()
+
+		path = simulateCrash(v)
+
+		v, err = OpenBlockVol(path)
+		if err != nil {
+			t.Fatalf("iter %d OpenBlockVol: %v", iter, err)
+		}
+
+		got, err := v.ReadLBA(0, 4096)
+		if err != nil {
+			t.Fatalf("iter %d ReadLBA: %v", iter, err)
+		}
+		if !bytes.Equal(got, makeBlock(fill)) {
+			t.Fatalf("iter %d: LBA 0 should be %d, got %d", iter, fill, got[0])
+		}
+	}
+
+	v.Close()
+	t.Logf("overwrite storm: 30 crash iterations, LBA 0 always correct")
+}
+
+// ============================================================================
+// QA Adversarial Tests — Round 4 (Architect-directed edge cases)
+// ============================================================================
+
+// --- WAL / Recovery Edge Cases ---
+
+func 
func TestQARecoveryEdgeCases(t *testing.T) {
    // Table-driven runner for WAL recovery edge cases; each entry is an
    // independent subtest with its own temp volume.
    tests := []struct {
        name string
        run  func(t *testing.T)
    }{
        {name: "qa_recover_entrysize_mismatch_at_tail", run: testQARecoverEntrySizeMismatchAtTail},
        {name: "qa_recover_partial_padding", run: testQARecoverPartialPadding},
        {name: "qa_recover_trim_then_write_same_lba", run: testQARecoverTrimThenWriteSameLBA},
        {name: "qa_recover_write_then_trim_same_lba", run: testQARecoverWriteThenTrimSameLBA},
        {name: "qa_recover_barrier_only_full_wal", run: testQARecoverBarrierOnlyFullWAL},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.run(t)
        })
    }
}

// testQARecoverEntrySizeMismatchAtTail: Corrupt the EntrySize field of the last
// WAL entry. Recovery should stop cleanly at the corrupt entry (CRC or EntrySize
// validation) without panic or returning an error.
func testQARecoverEntrySizeMismatchAtTail(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "entrysize.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write 3 entries.
    for i := uint64(0); i < 3; i++ {
        if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }
    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist the current WAL head/tail into the on-disk superblock so the
    // reopen after the simulated crash knows where the valid WAL window is.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    // Corrupt the EntrySize field (last 4 bytes) of the 3rd entry.
    entrySize := uint64(walEntryHeaderSize + 4096)
    // NOTE(review): assumes the three entries are laid out contiguously from
    // WALOffset with no padding between them — confirm against WALWriter.Append.
    thirdEntryEnd := v.super.WALOffset + entrySize*3
    entrySizeOff := int64(thirdEntryEnd - 4) // EntrySize is last 4 bytes
    var badSize [4]byte
    binary.LittleEndian.PutUint32(badSize[:], 99999)
    v.fd.WriteAt(badSize[:], entrySizeOff)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol should succeed (recovery stops at corrupt entry): %v", err)
    }
    defer v2.Close()

    // Entries 1 and 2 should be recovered.
    for i := uint64(0); i < 2; i++ {
        got, err := v2.ReadLBA(i, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", i, err)
        }
        if !bytes.Equal(got, makeBlock(byte('A'+i))) {
            t.Errorf("block %d: should be '%c' after recovery", i, byte('A'+i))
        }
    }

    // Entry 3 (corrupt) should NOT be recovered — read returns zeros.
    got, err := v2.ReadLBA(2, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(2): %v", err)
    }
    if !bytes.Equal(got, make([]byte, 4096)) {
        t.Error("block 2 (corrupt EntrySize) should return zeros")
    }
}

// testQARecoverPartialPadding: Write entries until WAL wraps with padding.
// Corrupt the padding entry to simulate truncation at EOF. Recovery should
// stop at the corrupt padding (torn write) and not advance past it.
func testQARecoverPartialPadding(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "partial_pad.blockvol")

    // WAL sized so first entry leaves a gap that needs padding.
    entrySize := uint64(walEntryHeaderSize + 4096) // 4134
    // WAL = 2 * entrySize + 50 bytes (50 bytes becomes padding on wrap).
    walSize := entrySize*2 + 50

    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    walSize,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write 2 entries (fills 2*4134 = 8268 bytes, leaving 50 bytes).
    if err := v.WriteLBA(0, makeBlock('A')); err != nil {
        t.Fatalf("WriteLBA(0): %v", err)
    }
    if err := v.WriteLBA(1, makeBlock('B')); err != nil {
        t.Fatalf("WriteLBA(1): %v", err)
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Advance tail past entry 1 so we have room to wrap.
    v.wal.AdvanceTail(entrySize)

    // Write entry 3 — this should trigger padding (50 bytes) at end and wrap to 0.
    if err := v.WriteLBA(2, makeBlock('C')); err != nil {
        t.Fatalf("WriteLBA(2) after wrap: %v", err)
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist WAL positions so recovery can replay after the crash.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    // Corrupt the padding region: overwrite the padding with garbage
    // to simulate a torn write at the padding boundary.
    paddingOff := int64(v.super.WALOffset + entrySize*2)
    garbage := bytes.Repeat([]byte{0xDE}, 50)
    v.fd.WriteAt(garbage, paddingOff)
    v.fd.Sync()

    path = simulateCrash(v)

    // Recovery should handle this — either skip corrupt padding and find
    // entry 3, or stop at the corruption. Either way, no panic.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol should not fail: %v", err)
    }
    defer v2.Close()

    // Entry 2 (LBA 1) should be recovered (it's before the padding).
    got, err := v2.ReadLBA(1, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(1): %v", err)
    }
    if !bytes.Equal(got, makeBlock('B')) {
        t.Error("block 1 should survive (before corrupt padding)")
    }
}

// testQARecoverTrimThenWriteSameLBA: TRIM LBA X, then WRITE same LBA, crash.
// Recovery should keep the WRITE (latest LSN wins).
func testQARecoverTrimThenWriteSameLBA(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "trim_write.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write initial data.
    if err := v.WriteLBA(5, makeBlock('X')); err != nil {
        t.Fatalf("WriteLBA initial: %v", err)
    }

    // Trim LBA 5.
    if err := v.Trim(5, 4096); err != nil {
        t.Fatalf("Trim: %v", err)
    }

    // Write LBA 5 again with new data.
    if err := v.WriteLBA(5, makeBlock('Y')); err != nil {
        t.Fatalf("WriteLBA after trim: %v", err)
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist WAL head/tail so recovery replays TRIM then WRITE in LSN order.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    // Latest WRITE should win over earlier TRIM.
    got, err := v2.ReadLBA(5, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(5): %v", err)
    }
    if !bytes.Equal(got, makeBlock('Y')) {
        t.Error("LBA 5 should be 'Y' (WRITE after TRIM wins)")
    }
}

// testQARecoverWriteThenTrimSameLBA: WRITE LBA X, then TRIM same LBA, crash.
// Recovery should return zeros (TRIM is latest).
func testQARecoverWriteThenTrimSameLBA(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "write_trim.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write LBA 7.
    if err := v.WriteLBA(7, makeBlock('W')); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Trim LBA 7.
    if err := v.Trim(7, 4096); err != nil {
        t.Fatalf("Trim: %v", err)
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist WAL window for recovery, then crash.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    // TRIM is latest — should return zeros.
    got, err := v2.ReadLBA(7, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(7): %v", err)
    }
    if !bytes.Equal(got, make([]byte, 4096)) {
        t.Error("LBA 7 should be zeros (TRIM after WRITE)")
    }
}

// testQARecoverBarrierOnlyFullWAL: Fill WAL entirely with BARRIER entries.
// Recovery should process them all without error but make no data changes.
func testQARecoverBarrierOnlyFullWAL(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "barrier_full.blockvol")

    // WAL sized for ~5 barrier entries (header-only, 38 bytes each).
    walSize := uint64(walEntryHeaderSize * 5)

    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    walSize,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Append barrier entries until WAL is full.
    var appended int
    for i := 0; i < 10; i++ {
        // Reserve an LSN: Add returns the post-increment value, so subtract 1
        // to get the LSN this entry owns.
        lsn := v.nextLSN.Add(1) - 1
        entry := &WALEntry{LSN: lsn, Type: EntryTypeBarrier, LBA: 0}
        if _, err := v.wal.Append(entry); err != nil {
            break // ErrWALFull
        }
        appended++
    }

    if appended == 0 {
        t.Fatal("couldn't append any barrier entries")
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist WAL window for recovery, then crash.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol with barrier-full WAL: %v", err)
    }
    defer v2.Close()

    // No data changes from barriers — read should return zeros.
    got, err := v2.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, make([]byte, 4096)) {
        t.Error("barrier-only WAL should leave data as zeros")
    }

    t.Logf("barrier-full WAL: %d barriers appended, recovery clean", appended)
}

// --- Flusher / Dirty Map Edge Cases ---

func TestQAFlusherEdgeCases(t *testing.T) {
    // Table-driven runner for flusher/dirty-map edge cases.
    tests := []struct {
        name string
        run  func(t *testing.T)
    }{
        {name: "qa_flush_interleaved_overwrite", run: testQAFlushInterleavedOverwrite},
        {name: "qa_flush_partial_wal_wrap", run: testQAFlushPartialWALWrap},
        {name: "qa_flush_trim_mixed_write", run: testQAFlushTrimMixedWrite},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.run(t)
        })
    }
}

// testQAFlushInterleavedOverwrite: Write LBA 0 three times with increasing LSN.
// Flush after first, overwrite twice more, flush again. Flusher's LSN-check
// should only remove entries matching the snapshot LSN.
func testQAFlushInterleavedOverwrite(t *testing.T) {
    v, f := createTestVolWithFlusher(t)
    defer v.Close()

    // Write LBA 0 = 'A' (LSN 1).
    if err := v.WriteLBA(0, makeBlock('A')); err != nil {
        t.Fatalf("WriteLBA(A): %v", err)
    }

    // Flush — moves 'A' to extent, removes dirty entry for LSN 1.
    if err := f.FlushOnce(); err != nil {
        t.Fatalf("FlushOnce 1: %v", err)
    }

    // Overwrite LBA 0 = 'B' (LSN 2).
    if err := v.WriteLBA(0, makeBlock('B')); err != nil {
        t.Fatalf("WriteLBA(B): %v", err)
    }

    // Overwrite LBA 0 = 'C' (LSN 3).
    if err := v.WriteLBA(0, makeBlock('C')); err != nil {
        t.Fatalf("WriteLBA(C): %v", err)
    }

    // Dirty map should have LBA 0 with LSN 3 (latest overwrite).
    _, lsn, _, ok := v.dirtyMap.Get(0)
    if !ok {
        t.Fatal("LBA 0 should be in dirty map")
    }

    // Flush — snapshot captures LSN 3. After flush, extent has 'C'.
    if err := f.FlushOnce(); err != nil {
        t.Fatalf("FlushOnce 2: %v", err)
    }

    // Dirty map should be empty (LSN matched, so flusher removed it).
    if v.dirtyMap.Len() != 0 {
        t.Errorf("dirty map should be empty after flush, got %d", v.dirtyMap.Len())
    }

    // Read should return 'C' from extent.
    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, makeBlock('C')) {
        t.Error("LBA 0 should be 'C' after interleaved overwrites + flush")
    }

    // NOTE(review): lsn is fetched but never asserted; consider verifying it
    // against the flush snapshot LSN instead of discarding it.
    _ = lsn // used for clarity in the test logic
}

// testQAFlushPartialWALWrap: Write entries until WAL wraps (with tail advance
// in between), then flush. Verify tail advance is correct and no WAL space leaks.
func testQAFlushPartialWALWrap(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "wrap_flush.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    128 * 1024, // small WAL
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }
    defer v.Close()

    // Second, manually-driven flusher over the same state; the 1h interval
    // ensures only explicit FlushOnce calls run it.
    f := NewFlusher(FlusherConfig{
        FD:       v.fd,
        Super:    &v.super,
        WAL:      v.wal,
        DirtyMap: v.dirtyMap,
        Interval: 1 * time.Hour, // manual only
    })

    entrySize := uint64(walEntryHeaderSize + 4096)
    maxEntries := int(128 * 1024 / entrySize)

    // Write ~60% capacity.
    firstBatch := maxEntries * 60 / 100
    for i := 0; i < firstBatch; i++ {
        if err := v.WriteLBA(uint64(i), makeBlock(byte('A'+i%26))); err != nil {
            t.Fatalf("batch1 WriteLBA(%d): %v", i, err)
        }
    }

    // Flush — moves all to extent, advances tail.
    if err := f.FlushOnce(); err != nil {
        t.Fatalf("FlushOnce 1: %v", err)
    }

    // Write more — these will wrap around in the WAL.
    for i := 0; i < firstBatch; i++ {
        lba := uint64(firstBatch + i)
        if lba >= 256 { // stay within volume
            break
        }
        if err := v.WriteLBA(lba, makeBlock(byte('a'+i%26))); err != nil {
            if errors.Is(err, ErrWALFull) {
                break
            }
            t.Fatalf("batch2 WriteLBA(%d): %v", lba, err)
        }
    }

    // Flush again — should handle wrapped entries correctly.
    if err := f.FlushOnce(); err != nil {
        t.Fatalf("FlushOnce 2: %v", err)
    }

    // Dirty map should be empty.
    if v.dirtyMap.Len() != 0 {
        t.Errorf("dirty map should be 0 after double flush, got %d", v.dirtyMap.Len())
    }

    // Write more to verify WAL space was properly reclaimed.
    for i := 0; i < 5; i++ {
        if err := v.WriteLBA(uint64(i), makeBlock(byte('Z'-i))); err != nil {
            t.Fatalf("post-wrap write %d: %v", i, err)
        }
    }

    // Verify latest writes.
    for i := 0; i < 5; i++ {
        got, err := v.ReadLBA(uint64(i), 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", i, err)
        }
        if !bytes.Equal(got, makeBlock(byte('Z'-i))) {
            t.Errorf("block %d: mismatch after wrap+flush+rewrite", i)
        }
    }
}

// testQAFlushTrimMixedWrite: Write some blocks, trim some, write others.
// Flush once. Verify extent has correct data (zeros for trimmed, data for written).
func testQAFlushTrimMixedWrite(t *testing.T) {
    v, f := createTestVolWithFlusher(t)
    defer v.Close()

    // Write LBAs 0-4.
    for i := uint64(0); i < 5; i++ {
        if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }

    // Trim LBAs 1 and 3.
    if err := v.Trim(1, 4096); err != nil {
        t.Fatalf("Trim(1): %v", err)
    }
    if err := v.Trim(3, 4096); err != nil {
        t.Fatalf("Trim(3): %v", err)
    }

    // Flush — should write data for 0,2,4 and zeros for 1,3.
    if err := f.FlushOnce(); err != nil {
        t.Fatalf("FlushOnce: %v", err)
    }

    // Dirty map should be empty.
    if v.dirtyMap.Len() != 0 {
        t.Errorf("dirty map should be empty, got %d", v.dirtyMap.Len())
    }

    // Verify from extent.
    expected := map[uint64][]byte{
        0: makeBlock('A'),
        1: make([]byte, 4096), // trimmed
        2: makeBlock('C'),
        3: make([]byte, 4096), // trimmed
        4: makeBlock('E'),
    }
    for lba, want := range expected {
        got, err := v.ReadLBA(lba, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", lba, err)
        }
        if !bytes.Equal(got, want) {
            t.Errorf("LBA %d: extent data mismatch after mixed flush", lba)
        }
    }
}

// --- Lifecycle + Concurrency Edge Cases ---

func TestQALifecycleConcurrency(t *testing.T) {
    // Table-driven runner for lifecycle/concurrency edge cases.
    tests := []struct {
        name string
        run  func(t *testing.T)
    }{
        {name: "qa_concurrent_flush_and_write", run: testQAConcurrentFlushAndWrite},
        {name: "qa_close_while_synccache_waits", run: testQACloseWhileSyncCacheWaits},
        {name: "qa_close_with_pending_dirtymap", run: testQACloseWithPendingDirtyMap},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.run(t)
        })
    }
}

// testQAConcurrentFlushAndWrite: Background flusher runs while writes happen.
// Crash after some time. Verify no data loss for synced writes.
func testQAConcurrentFlushAndWrite(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "conc_flush.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }
    // v already has background flusher running (100ms interval).

    // Write 50 blocks with SyncCache, while flusher runs in background.
    oracle := make(map[uint64]byte)
    for i := uint64(0); i < 50; i++ {
        fill := byte('A' + i%26)
        if err := v.WriteLBA(i, makeBlock(fill)); err != nil {
            if errors.Is(err, ErrWALFull) {
                // Flusher should free space, but if not fast enough, skip.
                // NOTE(review): timing-dependent — 150ms assumes the background
                // flusher interval is shorter; confirm against CreateBlockVol.
                time.Sleep(150 * time.Millisecond) // let flusher run
                if err := v.WriteLBA(i, makeBlock(fill)); err != nil {
                    continue // still full, skip
                }
            } else {
                t.Fatalf("WriteLBA(%d): %v", i, err)
            }
        }
        oracle[i] = fill

        // Sync periodically (every 10 writes).
        if i%10 == 9 {
            if err := v.SyncCache(); err != nil {
                t.Fatalf("SyncCache at %d: %v", i, err)
            }
        }
    }

    // Final sync.
    if err := v.SyncCache(); err != nil {
        t.Fatalf("final SyncCache: %v", err)
    }

    // Let flusher run one more cycle.
    time.Sleep(150 * time.Millisecond)

    // Persist superblock and crash.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    // Verify all oracle entries (some from extent, some from WAL replay).
    for lba, fill := range oracle {
        got, err := v2.ReadLBA(lba, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", lba, err)
        }
        if !bytes.Equal(got, makeBlock(fill)) {
            t.Errorf("block %d: mismatch after concurrent flush+crash", lba)
        }
    }
    t.Logf("concurrent flush+write: %d blocks verified", len(oracle))
}

// testQACloseWhileSyncCacheWaits: Start SyncCache in a goroutine, then Close.
// SyncCache should return ErrGroupCommitShutdown (not deadlock).
func testQACloseWhileSyncCacheWaits(t *testing.T) {
    v := createTestVol(t)

    if err := v.WriteLBA(0, makeBlock('X')); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Launch SyncCache in background.
    syncDone := make(chan error, 1)
    go func() {
        syncDone <- v.SyncCache()
    }()

    // Small delay to let SyncCache enqueue.
    time.Sleep(2 * time.Millisecond)

    // Close while SyncCache may be waiting.
    closeDone := make(chan error, 1)
    go func() {
        closeDone <- v.Close()
    }()

    // Both should complete without deadlock.
    select {
    case err := <-syncDone:
        // SyncCache either succeeded (fsync happened before close) or got shutdown error.
        if err != nil && !errors.Is(err, ErrGroupCommitShutdown) {
            t.Errorf("SyncCache: unexpected error: %v", err)
        }
    case <-time.After(5 * time.Second):
        t.Fatal("SyncCache deadlocked during Close")
    }

    select {
    case err := <-closeDone:
        // Close may return nil or an error from final flush — both are OK.
        _ = err
    case <-time.After(5 * time.Second):
        t.Fatal("Close deadlocked")
    }
}

// testQACloseWithPendingDirtyMap: Write blocks without sync, then Close.
// Close should flush dirty map. Reopen should show 0 dirty entries.
func testQACloseWithPendingDirtyMap(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "pending_dirty.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write 15 blocks without explicit SyncCache.
    for i := uint64(0); i < 15; i++ {
        if err := v.WriteLBA(i, makeBlock(byte('A'+i%26))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }

    // Close — should stop group committer, stop flusher, do final flush.
    if err := v.Close(); err != nil {
        t.Fatalf("Close: %v", err)
    }

    // Reopen — verify dirty map is empty (all data in extent).
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    if v2.dirtyMap.Len() != 0 {
        t.Errorf("dirty map after reopen should be 0, got %d", v2.dirtyMap.Len())
    }

    // Verify all blocks.
    for i := uint64(0); i < 15; i++ {
        got, err := v2.ReadLBA(i, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", i, err)
        }
        expected := makeBlock(byte('A' + i%26))
        if !bytes.Equal(got, expected) {
            t.Errorf("block %d: mismatch after close-with-pending", i)
        }
    }
}

// --- Parameter Extremes ---

func TestQAParameterExtremes(t *testing.T) {
    // Table-driven runner for non-standard parameter combinations.
    tests := []struct {
        name string
        run  func(t *testing.T)
    }{
        {name: "qa_blocksize_512_wal_small", run: testQABlockSize512WALSmall},
        {name: "qa_wal_size_min_header", run: testQAWALSizeMinHeader},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.run(t)
        })
    }
}

// testQABlockSize512WALSmall: 512-byte blocks with tiny WAL. Write, sync,
// crash, recover. Ensures no panics with non-standard parameters.
func testQABlockSize512WALSmall(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "bs512_small.blockvol")

    // 512-byte blocks, 4KB WAL (holds ~7 entries: (38+512)=550 per entry).
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 64 * 1024, // 64KB = 128 blocks of 512 bytes
        BlockSize:  512,
        WALSize:    4 * 1024, // 4KB WAL
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write a few blocks.
    data := make([]byte, 512)
    for i := range data {
        data[i] = 0xAB
    }
    for i := uint64(0); i < 5; i++ {
        if err := v.WriteLBA(i, data); err != nil {
            if errors.Is(err, ErrWALFull) {
                break
            }
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }

    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Persist WAL window for recovery, then crash.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALTail = v.wal.LogicalTail()
    v.fd.Seek(0, 0)
    v.super.WriteTo(v.fd)
    v.fd.Sync()

    path = simulateCrash(v)

    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    got, err := v2.ReadLBA(0, 512)
    if err != nil {
        t.Fatalf("ReadLBA(0): %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("512-byte block not recovered correctly")
    }
}

// testQAWALSizeMinHeader: WAL barely larger than one entry header.
// Should return ErrWALFull on first write without panicking.
func testQAWALSizeMinHeader(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "tiny_wal.blockvol")

    // WAL = walEntryHeaderSize + 1 byte — can't fit any entry with data.
    walSize := uint64(walEntryHeaderSize + 1)

    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    walSize,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }
    defer v.Close()

    // First write should fail with ErrWALFull (entry is 38+4096=4134 > 39 bytes).
    err = v.WriteLBA(0, makeBlock('X'))
    if err == nil {
        t.Fatal("expected ErrWALFull with tiny WAL")
    }
    if !errors.Is(err, ErrWALFull) {
        t.Errorf("expected ErrWALFull, got: %v", err)
    }

    // Volume should still be usable (read returns zeros, no panic).
    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA should work even with full WAL: %v", err)
    }
    if !bytes.Equal(got, make([]byte, 4096)) {
        t.Error("unwritten block should be zeros")
    }
}
diff --git a/weed/storage/blockvol/blockvol_test.go b/weed/storage/blockvol/blockvol_test.go
new file mode 100644
index 000000000..0fa274756
--- /dev/null
+++ b/weed/storage/blockvol/blockvol_test.go
@@ -0,0 +1,685 @@
package blockvol

import (
    "bytes"
    "encoding/binary"
    "errors"
    "os"
    "path/filepath"
    "testing"
)

func TestBlockVol(t *testing.T) {
    // Table-driven runner covering read/write, durability, corruption handling,
    // lifecycle, and the crash stress test.
    tests := []struct {
        name string
        run  func(t *testing.T)
    }{
        {name: "write_then_read", run: testWriteThenRead},
        {name: "overwrite_read_latest", run: testOverwriteReadLatest},
        {name: "write_no_sync_not_durable", run: testWriteNoSyncNotDurable},
        {name: "write_sync_durable", run: testWriteSyncDurable},
        {name: "write_multiple_sync", run: testWriteMultipleSync},
        {name: "read_unflushed", run: testReadUnflushed},
        {name: "read_flushed", run: testReadFlushed},
        {name: "read_mixed_dirty_clean", run: testReadMixedDirtyClean},
        {name: "wal_read_corrupt_length", run: testWALReadCorruptLength},
        {name: "open_invalid_superblock", run: testOpenInvalidSuperblock},
        {name: "trim_large_length_read_returns_zero", run: testTrimLargeLengthReadReturnsZero},
        // Task 1.10: Lifecycle tests.
        {name: "lifecycle_create_close_reopen", run: testLifecycleCreateCloseReopen},
        {name: "lifecycle_close_flushes_dirty", run: testLifecycleCloseFlushes},
        {name: "lifecycle_double_close", run: testLifecycleDoubleClose},
        {name: "lifecycle_info", run: testLifecycleInfo},
        {name: "lifecycle_write_sync_close_reopen", run: testLifecycleWriteSyncCloseReopen},
        // Task 1.11: Crash stress test.
        {name: "crash_stress_100", run: testCrashStress100},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.run(t)
        })
    }
}

// createTestVol builds a 1MB volume with a 256KB WAL in a per-test temp dir.
func createTestVol(t *testing.T) *BlockVol {
    t.Helper()
    dir := t.TempDir()
    path := filepath.Join(dir, "test.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024, // 1MB
        BlockSize:  4096,
        WALSize:    256 * 1024, // 256KB WAL
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }
    return v
}

// makeBlock returns a 4KB block filled with the given byte.
func makeBlock(fill byte) []byte {
    b := make([]byte, 4096)
    for i := range b {
        b[i] = fill
    }
    return b
}

func testWriteThenRead(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    data := makeBlock('A')
    if err := v.WriteLBA(0, data); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("read data does not match written data")
    }
}

func testOverwriteReadLatest(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    if err := v.WriteLBA(0, makeBlock('A')); err != nil {
        t.Fatalf("WriteLBA(A): %v", err)
    }
    if err := v.WriteLBA(0, makeBlock('B')); err != nil {
        t.Fatalf("WriteLBA(B): %v", err)
    }

    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, makeBlock('B')) {
        t.Error("read should return latest write ('B'), not 'A'")
    }
}

func testWriteNoSyncNotDurable(t *testing.T) {
    v := createTestVol(t)
    path := v.Path()

    if err := v.WriteLBA(0, makeBlock('A')); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Simulate crash: close fd without sync.
    v.fd.Close()

    // Reopen -- without recovery (Phase 1.9), data MAY be lost.
    // This test just verifies we can reopen without error.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol after crash: %v", err)
    }
    defer v2.Close()
    // Data may or may not be present -- both are correct without SyncCache.
}

func testWriteSyncDurable(t *testing.T) {
    v := createTestVol(t)
    path := v.Path()

    data := makeBlock('A')
    if err := v.WriteLBA(0, data); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Sync WAL (manual fsync for now, group commit in Task 1.7).
    if err := v.wal.Sync(); err != nil {
        t.Fatalf("Sync: %v", err)
    }

    // Update superblock WALHead so reopen knows where entries are.
    v.super.WALHead = v.wal.LogicalHead()
    v.super.WALCheckpointLSN = 0
    if _, err := v.fd.Seek(0, 0); err != nil {
        t.Fatalf("Seek: %v", err)
    }
    if _, err := v.super.WriteTo(v.fd); err != nil {
        t.Fatalf("WriteTo: %v", err)
    }
    v.fd.Sync()
    v.fd.Close()

    // Reopen and manually replay WAL to verify data is durable.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    // Manually replay: read WAL entry and populate dirty map.
    replayBuf := make([]byte, v2.super.WALHead)
    if _, err := v2.fd.ReadAt(replayBuf, int64(v2.super.WALOffset)); err != nil {
        t.Fatalf("read WAL for replay: %v", err)
    }
    entry, err := DecodeWALEntry(replayBuf)
    if err != nil {
        t.Fatalf("decode WAL entry: %v", err)
    }
    blocks := entry.Length / v2.super.BlockSize
    for i := uint32(0); i < blocks; i++ {
        v2.dirtyMap.Put(entry.LBA+uint64(i), 0, entry.LSN, v2.super.BlockSize)
    }

    got, err := v2.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA after recovery: %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("data not durable after sync + reopen")
    }
}

func testWriteMultipleSync(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    // Write 10 blocks with different data.
    for i := uint64(0); i < 10; i++ {
        if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }

    // Read all back.
    for i := uint64(0); i < 10; i++ {
        got, err := v.ReadLBA(i, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", i, err)
        }
        expected := makeBlock(byte('A' + i))
        if !bytes.Equal(got, expected) {
            t.Errorf("block %d: data mismatch", i)
        }
    }
}

func testReadUnflushed(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    data := makeBlock('X')
    if err := v.WriteLBA(0, data); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Read before flusher runs -- should come from dirty map / WAL.
    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("unflushed read: data mismatch")
    }
}

func testReadFlushed(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    data := makeBlock('F')
    if err := v.WriteLBA(0, data); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Manually flush: copy WAL data to extent region, clear dirty map.
    extentStart := v.super.WALOffset + v.super.WALSize
    if _, err := v.fd.WriteAt(data, int64(extentStart)); err != nil {
        t.Fatalf("manual flush write: %v", err)
    }
    v.dirtyMap.Delete(0)

    // Read should now come from extent region.
    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA after flush: %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("flushed read: data mismatch")
    }
}

func testReadMixedDirtyClean(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    // Write blocks 0, 2, 4 (dirty).
    for _, lba := range []uint64{0, 2, 4} {
        if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", lba, err)
        }
    }

    // Manually flush block 0 to extent, remove from dirty map.
    extentStart := v.super.WALOffset + v.super.WALSize
    if _, err := v.fd.WriteAt(makeBlock('A'), int64(extentStart)); err != nil {
        t.Fatalf("manual flush: %v", err)
    }
    v.dirtyMap.Delete(0)

    // Block 0: from extent (flushed)
    got, err := v.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(0): %v", err)
    }
    if !bytes.Equal(got, makeBlock('A')) {
        t.Error("block 0 (flushed) mismatch")
    }

    // Block 2: from dirty map (WAL)
    got, err = v.ReadLBA(2, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(2): %v", err)
    }
    if !bytes.Equal(got, makeBlock('C')) { // 'A'+2 = 'C'
        t.Error("block 2 (dirty) mismatch")
    }

    // Block 4: from dirty map (WAL)
    got, err = v.ReadLBA(4, 4096)
    if err != nil {
        t.Fatalf("ReadLBA(4): %v", err)
    }
    if !bytes.Equal(got, makeBlock('E')) { // 'A'+4 = 'E'
        t.Error("block 4 (dirty) mismatch")
    }

    // Blocks 1, 3, 5: never written, should be zeros (from extent region).
    for _, lba := range []uint64{1, 3, 5} {
        got, err = v.ReadLBA(lba, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", lba, err)
        }
        if !bytes.Equal(got, make([]byte, 4096)) {
            t.Errorf("block %d (unwritten) should be zeros", lba)
        }
    }
}

func testWALReadCorruptLength(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    // Write a valid block.
    if err := v.WriteLBA(0, makeBlock('A')); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }

    // Get the WAL offset from dirty map.
    walOff, _, _, ok := v.dirtyMap.Get(0)
    if !ok {
        t.Fatal("block 0 not in dirty map")
    }

    // Corrupt the Length field in the WAL entry on disk.
    // Length is at header offset 26 (LSN=8 + Epoch=8 + Type=1 + Flags=1 + LBA=8).
    absOff := int64(v.super.WALOffset + walOff)
    lengthOff := absOff + 26
    var hugeLenBuf [4]byte
    binary.LittleEndian.PutUint32(hugeLenBuf[:], 999999999) // ~1GB
    if _, err := v.fd.WriteAt(hugeLenBuf[:], lengthOff); err != nil {
        t.Fatalf("corrupt length: %v", err)
    }

    // ReadLBA should detect the corrupt length and error (not panic/OOM).
    _, err := v.ReadLBA(0, 4096)
    if err == nil {
        t.Error("expected error reading corrupt WAL entry, got nil")
    }
}

func testOpenInvalidSuperblock(t *testing.T) {
    dir := t.TempDir()

    // Create a valid volume, then corrupt BlockSize to 0.
    path := filepath.Join(dir, "corrupt.blockvol")
    v, err := CreateBlockVol(path, CreateOptions{VolumeSize: 1024 * 1024})
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }
    v.Close()

    // Corrupt BlockSize (at superblock offset 36: Magic=4 + Version=2 + Flags=2 + UUID=16 + VolumeSize=8 + ExtentSize=4 = 36).
    fd, err := os.OpenFile(path, os.O_RDWR, 0644)
    if err != nil {
        t.Fatalf("open for corrupt: %v", err)
    }
    var zeroBuf [4]byte
    if _, err := fd.WriteAt(zeroBuf[:], 36); err != nil {
        fd.Close()
        t.Fatalf("corrupt blocksize: %v", err)
    }
    fd.Close()

    // OpenBlockVol should reject the corrupt superblock.
    _, err = OpenBlockVol(path)
    if err == nil {
        t.Fatal("expected error opening volume with BlockSize=0")
    }
    if !errors.Is(err, ErrInvalidSuperblock) {
        t.Errorf("expected ErrInvalidSuperblock, got: %v", err)
    }
}

// --- Task 1.10: Lifecycle tests ---

func testLifecycleCreateCloseReopen(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "lifecycle.blockvol")

    // Create, write, close.
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    data := makeBlock('L')
    if err := v.WriteLBA(0, data); err != nil {
        t.Fatalf("WriteLBA: %v", err)
    }
    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }
    if err := v.Close(); err != nil {
        t.Fatalf("Close: %v", err)
    }

    // Reopen and verify data survived.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    got, err := v2.ReadLBA(0, 4096)
    if err != nil {
        t.Fatalf("ReadLBA: %v", err)
    }
    if !bytes.Equal(got, data) {
        t.Error("data not durable after create→write→sync→close→reopen")
    }
}

func testLifecycleCloseFlushes(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "flush.blockvol")

    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    // Write several blocks.
    for i := uint64(0); i < 5; i++ {
        if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
            t.Fatalf("WriteLBA(%d): %v", i, err)
        }
    }
    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache: %v", err)
    }

    // Close does a final flush — dirty map should be drained.
    if err := v.Close(); err != nil {
        t.Fatalf("Close: %v", err)
    }

    // Reopen and verify.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol: %v", err)
    }
    defer v2.Close()

    for i := uint64(0); i < 5; i++ {
        got, err := v2.ReadLBA(i, 4096)
        if err != nil {
            t.Fatalf("ReadLBA(%d): %v", i, err)
        }
        expected := makeBlock(byte('A' + i))
        if !bytes.Equal(got, expected) {
            t.Errorf("block %d: data mismatch after close+reopen", i)
        }
    }
}

func testLifecycleDoubleClose(t *testing.T) {
    v := createTestVol(t)
    if err := v.Close(); err != nil {
        t.Fatalf("first Close: %v", err)
    }
    // Second close should not panic (group committer + flusher are idempotent).
    // The fd.Close() will return an error but should not panic.
    _ = v.Close()
}

func testLifecycleInfo(t *testing.T) {
    v := createTestVol(t)
    defer v.Close()

    info := v.Info()
    if info.VolumeSize != 1*1024*1024 {
        t.Errorf("VolumeSize = %d, want %d", info.VolumeSize, 1*1024*1024)
    }
    if info.BlockSize != 4096 {
        t.Errorf("BlockSize = %d, want 4096", info.BlockSize)
    }
    if !info.Healthy {
        t.Error("Healthy = false, want true")
    }
}

func testLifecycleWriteSyncCloseReopen(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "wsco.blockvol")

    // Cycle: create → write → sync → close → reopen → write → sync → close → verify.
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: 1 * 1024 * 1024,
        BlockSize:  4096,
        WALSize:    256 * 1024,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    if err := v.WriteLBA(0, makeBlock('X')); err != nil {
        t.Fatalf("WriteLBA round 1: %v", err)
    }
    if err := v.SyncCache(); err != nil {
        t.Fatalf("SyncCache round 1: %v", err)
    }
    v.Close()

    // Reopen, write more.
    v2, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol round 2: %v", err)
    }

    if err := v2.WriteLBA(1, makeBlock('Y')); err != nil {
        t.Fatalf("WriteLBA round 2: %v", err)
    }
    // Overwrite block 0.
    if err := v2.WriteLBA(0, makeBlock('Z')); err != nil {
        t.Fatalf("WriteLBA overwrite: %v", err)
    }
    if err := v2.SyncCache(); err != nil {
        t.Fatalf("SyncCache round 2: %v", err)
    }
    v2.Close()

    // Final reopen — verify.
    v3, err := OpenBlockVol(path)
    if err != nil {
        t.Fatalf("OpenBlockVol round 3: %v", err)
    }
    defer v3.Close()

    got0, _ := v3.ReadLBA(0, 4096)
    if !bytes.Equal(got0, makeBlock('Z')) {
        t.Error("block 0: expected 'Z' (overwritten)")
    }
    got1, _ := v3.ReadLBA(1, 4096)
    if !bytes.Equal(got1, makeBlock('Y')) {
        t.Error("block 1: expected 'Y'")
    }
}

// --- Task 1.11: Crash stress test ---

func testCrashStress100(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "stress.blockvol")

    const (
        volumeSize = 256 * 1024 // 256KB volume (64 blocks of 4KB)
        blockSize  = 4096
        walSize    = 64 * 1024 // 64KB WAL
        maxLBA     = volumeSize / blockSize
        iterations = 100
    )

    // Oracle: tracks the expected state of each block.
    oracle := make(map[uint64]byte) // lba → fill byte (0 = zeros/trimmed)

    // Create initial volume.
    v, err := CreateBlockVol(path, CreateOptions{
        VolumeSize: volumeSize,
        BlockSize:  blockSize,
        WALSize:    walSize,
    })
    if err != nil {
        t.Fatalf("CreateBlockVol: %v", err)
    }

    for iter := 0; iter < iterations; iter++ {
        // Deterministic "random" ops using iteration number.
        numOps := 3 + (iter % 5) // 3-7 ops per iteration

        for op := 0; op < numOps; op++ {
            lba := uint64((iter*7 + op*13) % maxLBA)
            action := (iter + op) % 3 // 0=write, 1=overwrite, 2=trim

            switch action {
            case 0, 1: // write
                fill := byte('A' + (iter+op)%26)
                err := v.WriteLBA(lba, makeBlock(fill))
                if err != nil {
                    if errors.Is(err, ErrWALFull) {
                        continue // WAL full, skip this op
                    }
                    t.Fatalf("iter %d op %d: WriteLBA(%d): %v", iter, op, lba, err)
                }
                oracle[lba] = fill
            case 2: // trim
                err := v.Trim(lba, blockSize)
                if err != nil {
                    if errors.Is(err, ErrWALFull) {
                        continue
                    }
                    t.Fatalf("iter %d op %d: Trim(%d): %v", iter, op, lba, err)
                }
                oracle[lba] = 0
            }
        }

        // Sync WAL.
        if err := v.SyncCache(); err != nil {
            t.Fatalf("iter %d: SyncCache: %v", iter, err)
        }

        // Simulate crash: close fd without clean shutdown.
        v.fd.Sync() // ensure WAL is on disk
        // Write superblock with current WAL positions for recovery.
        v.super.WALHead = v.wal.LogicalHead()
        v.super.WALTail = v.wal.LogicalTail()
        if _, seekErr := v.fd.Seek(0, 0); seekErr != nil {
            t.Fatalf("iter %d: Seek: %v", iter, seekErr)
        }
        v.super.WriteTo(v.fd)
        v.fd.Sync()

        // Hard crash: close fd, stop goroutines.
        v.groupCommit.Stop()
        v.flusher.Stop()
        v.fd.Close()

        // Reopen with recovery.
        v, err = OpenBlockVol(path)
        if err != nil {
            t.Fatalf("iter %d: OpenBlockVol: %v", iter, err)
        }

        // Verify oracle against actual reads.
+ for lba, fill := range oracle { + got, readErr := v.ReadLBA(lba, blockSize) + if readErr != nil { + t.Fatalf("iter %d: ReadLBA(%d): %v", iter, lba, readErr) + } + var expected []byte + if fill == 0 { + expected = make([]byte, blockSize) + } else { + expected = makeBlock(fill) + } + if !bytes.Equal(got, expected) { + t.Fatalf("iter %d: block %d mismatch: got[0]=%d want[0]=%d", iter, lba, got[0], expected[0]) + } + } + } + + v.Close() +} + +func testTrimLargeLengthReadReturnsZero(t *testing.T) { + v := createTestVol(t) + defer v.Close() + + // Write data first. + data := makeBlock('A') + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Verify data is written. + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA before trim: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatal("data mismatch before trim") + } + + // Trim with a length larger than WAL size — should still work. + // The trim Length is metadata (trim extent), not a data allocation. + if err := v.Trim(0, 4096); err != nil { + t.Fatalf("Trim: %v", err) + } + + // Read should return zeros (TRIM entry in dirty map). + got, err = v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after trim: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("read after trim should return zeros") + } +} diff --git a/weed/storage/blockvol/dirty_map.go b/weed/storage/blockvol/dirty_map.go new file mode 100644 index 000000000..ec6905ce3 --- /dev/null +++ b/weed/storage/blockvol/dirty_map.go @@ -0,0 +1,110 @@ +package blockvol + +import "sync" + +// dirtyEntry tracks a single LBA's location in the WAL. +type dirtyEntry struct { + walOffset uint64 + lsn uint64 + length uint32 +} + +// DirtyMap is a concurrency-safe map from LBA to WAL offset. +// Phase 1 uses a single shard; Phase 3 upgrades to 256 shards. +type DirtyMap struct { + mu sync.RWMutex + m map[uint64]dirtyEntry +} + +// NewDirtyMap creates an empty dirty map. 
+func NewDirtyMap() *DirtyMap { + return &DirtyMap{m: make(map[uint64]dirtyEntry)} +} + +// Put records that the given LBA has dirty data at walOffset. +func (d *DirtyMap) Put(lba uint64, walOffset uint64, lsn uint64, length uint32) { + d.mu.Lock() + d.m[lba] = dirtyEntry{walOffset: walOffset, lsn: lsn, length: length} + d.mu.Unlock() +} + +// Get returns the dirty entry for lba. ok is false if lba is not dirty. +func (d *DirtyMap) Get(lba uint64) (walOffset uint64, lsn uint64, length uint32, ok bool) { + d.mu.RLock() + e, found := d.m[lba] + d.mu.RUnlock() + if !found { + return 0, 0, 0, false + } + return e.walOffset, e.lsn, e.length, true +} + +// Delete removes a single LBA from the dirty map. +func (d *DirtyMap) Delete(lba uint64) { + d.mu.Lock() + delete(d.m, lba) + d.mu.Unlock() +} + +// rangeEntry is a snapshot of one dirty entry for lock-free iteration. +type rangeEntry struct { + lba uint64 + walOffset uint64 + lsn uint64 + length uint32 +} + +// Range calls fn for each dirty entry with LBA in [start, start+count). +// Entries are copied under lock, then fn is called without holding the lock, +// so fn may safely call back into DirtyMap (Put, Delete, etc.). +func (d *DirtyMap) Range(start uint64, count uint32, fn func(lba, walOffset, lsn uint64, length uint32)) { + end := start + uint64(count) + + // Snapshot matching entries under lock. + d.mu.RLock() + entries := make([]rangeEntry, 0, len(d.m)) + for lba, e := range d.m { + if lba >= start && lba < end { + entries = append(entries, rangeEntry{lba, e.walOffset, e.lsn, e.length}) + } + } + d.mu.RUnlock() + + // Call fn without lock. + for _, e := range entries { + fn(e.lba, e.walOffset, e.lsn, e.length) + } +} + +// Len returns the number of dirty entries. +func (d *DirtyMap) Len() int { + d.mu.RLock() + n := len(d.m) + d.mu.RUnlock() + return n +} + +// SnapshotEntry is an exported snapshot of one dirty entry. 
+type SnapshotEntry struct { + Lba uint64 + WalOffset uint64 + Lsn uint64 + Length uint32 +} + +// Snapshot returns a copy of all dirty entries. The snapshot is taken under +// read lock but returned without holding the lock. +func (d *DirtyMap) Snapshot() []SnapshotEntry { + d.mu.RLock() + entries := make([]SnapshotEntry, 0, len(d.m)) + for lba, e := range d.m { + entries = append(entries, SnapshotEntry{ + Lba: lba, + WalOffset: e.walOffset, + Lsn: e.lsn, + Length: e.length, + }) + } + d.mu.RUnlock() + return entries +} diff --git a/weed/storage/blockvol/dirty_map_test.go b/weed/storage/blockvol/dirty_map_test.go new file mode 100644 index 000000000..872412416 --- /dev/null +++ b/weed/storage/blockvol/dirty_map_test.go @@ -0,0 +1,153 @@ +package blockvol + +import ( + "sync" + "testing" +) + +func TestDirtyMap(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "dirty_put_get", run: testDirtyPutGet}, + {name: "dirty_overwrite", run: testDirtyOverwrite}, + {name: "dirty_delete", run: testDirtyDelete}, + {name: "dirty_range_query", run: testDirtyRangeQuery}, + {name: "dirty_empty", run: testDirtyEmpty}, + {name: "dirty_concurrent_rw", run: testDirtyConcurrentRW}, + {name: "dirty_range_modify", run: testDirtyRangeModify}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testDirtyPutGet(t *testing.T) { + dm := NewDirtyMap() + dm.Put(100, 5000, 1, 4096) + + off, lsn, length, ok := dm.Get(100) + if !ok { + t.Fatal("Get(100) returned not-found") + } + if off != 5000 { + t.Errorf("walOffset = %d, want 5000", off) + } + if lsn != 1 { + t.Errorf("lsn = %d, want 1", lsn) + } + if length != 4096 { + t.Errorf("length = %d, want 4096", length) + } +} + +func testDirtyOverwrite(t *testing.T) { + dm := NewDirtyMap() + dm.Put(100, 5000, 1, 4096) + dm.Put(100, 9000, 2, 4096) + + off, lsn, _, ok := dm.Get(100) + if !ok { + t.Fatal("Get(100) returned not-found") + } + if off != 9000 { + 
t.Errorf("walOffset = %d, want 9000 (latest)", off) + } + if lsn != 2 { + t.Errorf("lsn = %d, want 2 (latest)", lsn) + } +} + +func testDirtyDelete(t *testing.T) { + dm := NewDirtyMap() + dm.Put(100, 5000, 1, 4096) + dm.Delete(100) + + _, _, _, ok := dm.Get(100) + if ok { + t.Error("Get(100) should return not-found after delete") + } +} + +func testDirtyRangeQuery(t *testing.T) { + dm := NewDirtyMap() + for i := uint64(100); i < 110; i++ { + dm.Put(i, i*1000, i, 4096) + } + + found := make(map[uint64]bool) + dm.Range(100, 10, func(lba, walOffset, lsn uint64, length uint32) { + found[lba] = true + }) + + if len(found) != 10 { + t.Errorf("Range returned %d entries, want 10", len(found)) + } + for i := uint64(100); i < 110; i++ { + if !found[i] { + t.Errorf("Range missing LBA %d", i) + } + } +} + +func testDirtyEmpty(t *testing.T) { + dm := NewDirtyMap() + + _, _, _, ok := dm.Get(0) + if ok { + t.Error("empty map: Get(0) should return not-found") + } + _, _, _, ok = dm.Get(999) + if ok { + t.Error("empty map: Get(999) should return not-found") + } + if dm.Len() != 0 { + t.Errorf("empty map: Len() = %d, want 0", dm.Len()) + } +} + +func testDirtyConcurrentRW(t *testing.T) { + dm := NewDirtyMap() + const goroutines = 16 + const opsPerGoroutine = 1000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + base := uint64(id * opsPerGoroutine) + for i := uint64(0); i < opsPerGoroutine; i++ { + lba := base + i + dm.Put(lba, lba*10, lba, 4096) + dm.Get(lba) + if i%3 == 0 { + dm.Delete(lba) + } + } + }(g) + } + wg.Wait() + + // No assertion on final state -- the test passes if no race detector fires. +} + +func testDirtyRangeModify(t *testing.T) { + // Verify that Range callback can safely call Delete without deadlock. + dm := NewDirtyMap() + for i := uint64(0); i < 10; i++ { + dm.Put(i, i*1000, i+1, 4096) + } + + // Delete every entry from within the Range callback. 
+ dm.Range(0, 10, func(lba, walOffset, lsn uint64, length uint32) { + dm.Delete(lba) + }) + + if dm.Len() != 0 { + t.Errorf("after Range+Delete: Len() = %d, want 0", dm.Len()) + } +} diff --git a/weed/storage/blockvol/flusher.go b/weed/storage/blockvol/flusher.go new file mode 100644 index 000000000..d40dd6558 --- /dev/null +++ b/weed/storage/blockvol/flusher.go @@ -0,0 +1,239 @@ +package blockvol + +import ( + "encoding/binary" + "fmt" + "os" + "sync" + "time" +) + +// Flusher copies WAL entries to the extent region and frees WAL space. +// It runs as a background goroutine and can also be triggered manually. +type Flusher struct { + fd *os.File + super *Superblock + wal *WALWriter + dirtyMap *DirtyMap + walOffset uint64 // absolute file offset of WAL region + walSize uint64 + blockSize uint32 + extentStart uint64 // absolute file offset of extent region + + mu sync.Mutex + checkpointLSN uint64 // last flushed LSN + checkpointTail uint64 // WAL physical tail after last flush + + interval time.Duration + notifyCh chan struct{} + stopCh chan struct{} + done chan struct{} + stopOnce sync.Once +} + +// FlusherConfig configures the flusher. +type FlusherConfig struct { + FD *os.File + Super *Superblock + WAL *WALWriter + DirtyMap *DirtyMap + Interval time.Duration // default 100ms +} + +// NewFlusher creates a flusher. Call Run() in a goroutine. +func NewFlusher(cfg FlusherConfig) *Flusher { + if cfg.Interval == 0 { + cfg.Interval = 100 * time.Millisecond + } + return &Flusher{ + fd: cfg.FD, + super: cfg.Super, + wal: cfg.WAL, + dirtyMap: cfg.DirtyMap, + walOffset: cfg.Super.WALOffset, + walSize: cfg.Super.WALSize, + blockSize: cfg.Super.BlockSize, + extentStart: cfg.Super.WALOffset + cfg.Super.WALSize, + checkpointLSN: cfg.Super.WALCheckpointLSN, + checkpointTail: 0, + interval: cfg.Interval, + notifyCh: make(chan struct{}, 1), + stopCh: make(chan struct{}), + done: make(chan struct{}), + } +} + +// Run is the flusher main loop. Call in a goroutine. 
+func (f *Flusher) Run() { + defer close(f.done) + ticker := time.NewTicker(f.interval) + defer ticker.Stop() + + for { + select { + case <-f.stopCh: + return + case <-ticker.C: + f.FlushOnce() + case <-f.notifyCh: + f.FlushOnce() + } + } +} + +// Notify wakes up the flusher for an immediate flush cycle. +func (f *Flusher) Notify() { + select { + case f.notifyCh <- struct{}{}: + default: + } +} + +// Stop shuts down the flusher. Safe to call multiple times. +func (f *Flusher) Stop() { + f.stopOnce.Do(func() { + close(f.stopCh) + }) + <-f.done +} + +// FlushOnce performs a single flush cycle: scan dirty map, copy data to +// extent region, fsync, update checkpoint, advance WAL tail. +func (f *Flusher) FlushOnce() error { + // Snapshot dirty entries. We use a full scan (Range over all possible LBAs + // is impractical), so we collect from the dirty map directly. + type flushEntry struct { + lba uint64 + walOff uint64 + lsn uint64 + length uint32 + } + + entries := f.dirtyMap.Snapshot() + if len(entries) == 0 { + return nil + } + + // Find the max LSN and max WAL offset to know where to advance tail. + var maxLSN uint64 + var maxWALEnd uint64 + + for _, e := range entries { + // Read the WAL entry and copy data to extent region. + headerBuf := make([]byte, walEntryHeaderSize) + absWALOff := int64(f.walOffset + e.WalOffset) + if _, err := f.fd.ReadAt(headerBuf, absWALOff); err != nil { + return fmt.Errorf("flusher: read WAL header at %d: %w", absWALOff, err) + } + + // Parse entry type and length. + entryType := headerBuf[16] // Type at LSN(8)+Epoch(8)=16 + dataLen := parseLength(headerBuf) + + if entryType == EntryTypeWrite && dataLen > 0 { + // Read full entry. 
+ entryLen := walEntryHeaderSize + int(dataLen) + fullBuf := make([]byte, entryLen) + if _, err := f.fd.ReadAt(fullBuf, absWALOff); err != nil { + return fmt.Errorf("flusher: read WAL entry at %d: %w", absWALOff, err) + } + + entry, err := DecodeWALEntry(fullBuf) + if err != nil { + return fmt.Errorf("flusher: decode WAL entry: %w", err) + } + + // Write data to extent region. + blocks := entry.Length / f.blockSize + for i := uint32(0); i < blocks; i++ { + blockLBA := entry.LBA + uint64(i) + extentOff := int64(f.extentStart + blockLBA*uint64(f.blockSize)) + blockData := entry.Data[i*f.blockSize : (i+1)*f.blockSize] + if _, err := f.fd.WriteAt(blockData, extentOff); err != nil { + return fmt.Errorf("flusher: write extent at LBA %d: %w", blockLBA, err) + } + } + + // Track WAL end position for tail advance. + walEnd := e.WalOffset + uint64(entryLen) + if walEnd > maxWALEnd { + maxWALEnd = walEnd + } + } else if entryType == EntryTypeTrim { + // TRIM entries: zero the extent region for this LBA. + // Each dirty map entry represents one trimmed block. + zeroBlock := make([]byte, f.blockSize) + extentOff := int64(f.extentStart + e.Lba*uint64(f.blockSize)) + if _, err := f.fd.WriteAt(zeroBlock, extentOff); err != nil { + return fmt.Errorf("flusher: zero extent at LBA %d: %w", e.Lba, err) + } + + // TRIM entry has no data payload, just a header. + walEnd := e.WalOffset + uint64(walEntryHeaderSize) + if walEnd > maxWALEnd { + maxWALEnd = walEnd + } + } + + if e.Lsn > maxLSN { + maxLSN = e.Lsn + } + } + + // Fsync extent writes. + if err := f.fd.Sync(); err != nil { + return fmt.Errorf("flusher: fsync extent: %w", err) + } + + // Remove flushed entries from dirty map. + f.mu.Lock() + for _, e := range entries { + // Only remove if the dirty map entry still has the same LSN + // (a newer write may have updated it). 
+ _, currentLSN, _, ok := f.dirtyMap.Get(e.Lba) + if ok && currentLSN == e.Lsn { + f.dirtyMap.Delete(e.Lba) + } + } + f.checkpointLSN = maxLSN + f.checkpointTail = maxWALEnd + f.mu.Unlock() + + // Advance WAL tail to free space. + if maxWALEnd > 0 { + f.wal.AdvanceTail(maxWALEnd) + } + + // Update superblock checkpoint. + f.updateSuperblockCheckpoint(maxLSN, f.wal.Tail()) + + return nil +} + +// updateSuperblockCheckpoint writes the updated checkpoint to disk. +func (f *Flusher) updateSuperblockCheckpoint(checkpointLSN uint64, walTail uint64) error { + f.super.WALCheckpointLSN = checkpointLSN + f.super.WALHead = f.wal.LogicalHead() + f.super.WALTail = f.wal.LogicalTail() + + if _, err := f.fd.Seek(0, 0); err != nil { + return fmt.Errorf("flusher: seek to superblock: %w", err) + } + if _, err := f.super.WriteTo(f.fd); err != nil { + return fmt.Errorf("flusher: write superblock: %w", err) + } + return f.fd.Sync() +} + +// CheckpointLSN returns the last flushed LSN. +func (f *Flusher) CheckpointLSN() uint64 { + f.mu.Lock() + defer f.mu.Unlock() + return f.checkpointLSN +} + +// parseLength extracts the Length field from a WAL entry header buffer. 
+func parseLength(headerBuf []byte) uint32 { + // Length at LSN(8)+Epoch(8)+Type(1)+Flags(1)+LBA(8) = 26 + return binary.LittleEndian.Uint32(headerBuf[26:]) +} diff --git a/weed/storage/blockvol/flusher_test.go b/weed/storage/blockvol/flusher_test.go new file mode 100644 index 000000000..cc6ab2792 --- /dev/null +++ b/weed/storage/blockvol/flusher_test.go @@ -0,0 +1,261 @@ +package blockvol + +import ( + "bytes" + "path/filepath" + "testing" + "time" +) + +func TestFlusher(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "flush_moves_data", run: testFlushMovesData}, + {name: "flush_idempotent", run: testFlushIdempotent}, + {name: "flush_concurrent_writes", run: testFlushConcurrentWrites}, + {name: "flush_frees_wal_space", run: testFlushFreesWALSpace}, + {name: "flush_partial", run: testFlushPartial}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func createTestVolWithFlusher(t *testing.T) (*BlockVol, *Flusher) { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "test.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, // 1MB + BlockSize: 4096, + WALSize: 256 * 1024, // 256KB WAL + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + f := NewFlusher(FlusherConfig{ + FD: v.fd, + Super: &v.super, + WAL: v.wal, + DirtyMap: v.dirtyMap, + Interval: 1 * time.Hour, // don't auto-flush in tests + }) + + return v, f +} + +func testFlushMovesData(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write 10 blocks. + for i := uint64(0); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + if v.dirtyMap.Len() != 10 { + t.Fatalf("dirty map len = %d, want 10", v.dirtyMap.Len()) + } + + // Run flusher. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // Dirty map should be empty. 
+ if v.dirtyMap.Len() != 0 { + t.Errorf("after flush: dirty map len = %d, want 0", v.dirtyMap.Len()) + } + + // Checkpoint should have advanced. + if f.CheckpointLSN() == 0 { + t.Error("checkpoint LSN should be > 0 after flush") + } + + // Read from extent (dirty map is empty, so reads go to extent). + for i := uint64(0); i < 10; i++ { + got, err := v.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d) after flush: %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('A'+i))) { + t.Errorf("block %d: data mismatch after flush", i) + } + } +} + +func testFlushIdempotent(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + data := makeBlock('X') + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Flush twice. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 1: %v", err) + } + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 2: %v", err) + } + + // Data should still be correct. + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after double flush: %v", err) + } + if !bytes.Equal(got, data) { + t.Error("data mismatch after double flush") + } +} + +func testFlushConcurrentWrites(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write blocks 0-4. + for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Flush (moves blocks 0-4 to extent). + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // Write blocks 5-9 AFTER flush. + for i := uint64(5); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Blocks 0-4 should read from extent, blocks 5-9 from WAL. 
+ for i := uint64(0); i < 10; i++ { + got, err := v.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('A'+i))) { + t.Errorf("block %d: data mismatch", i) + } + } + + // Dirty map should have 5 entries (blocks 5-9). + if v.dirtyMap.Len() != 5 { + t.Errorf("dirty map len = %d, want 5", v.dirtyMap.Len()) + } + + // Also: overwrite block 0 after flush -- new write should go to WAL. + newData := makeBlock('Z') + if err := v.WriteLBA(0, newData); err != nil { + t.Fatalf("WriteLBA(0) overwrite: %v", err) + } + got, err := v.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0) after overwrite: %v", err) + } + if !bytes.Equal(got, newData) { + t.Error("block 0: should return overwritten data 'Z'") + } +} + +func testFlushFreesWALSpace(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write enough blocks to fill a significant portion of WAL. + entrySize := uint64(walEntryHeaderSize + 4096) + walCapacity := v.super.WALSize / entrySize + // Write ~80% of capacity. + writeCount := int(walCapacity * 80 / 100) + + for i := 0; i < writeCount; i++ { + if err := v.WriteLBA(uint64(i), makeBlock(byte(i%26+'A'))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Try to write more -- should eventually fail with WAL full. + var walFullBefore bool + for i := writeCount; i < writeCount+int(walCapacity); i++ { + if err := v.WriteLBA(uint64(i%writeCount), makeBlock('X')); err != nil { + walFullBefore = true + break + } + } + + // Flush to free WAL space. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // WAL tail should have advanced (free space available). + // New writes should succeed. + if err := v.WriteLBA(0, makeBlock('Y')); err != nil { + t.Fatalf("WriteLBA after flush: %v", err) + } + + // Log whether WAL was full before flush. 
+ if walFullBefore { + t.Log("WAL was full before flush, writes succeeded after flush") + } +} + +func testFlushPartial(t *testing.T) { + v, f := createTestVolWithFlusher(t) + defer v.Close() + + // Write blocks 0-4. + for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Flush once (all 5 blocks). + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + checkpointAfterFirst := f.CheckpointLSN() + + // Write blocks 5-9. + for i := uint64(5); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + // Simulate partial flush: flusher runs again, should handle new entries. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce 2: %v", err) + } + + checkpointAfterSecond := f.CheckpointLSN() + if checkpointAfterSecond <= checkpointAfterFirst { + t.Errorf("checkpoint should advance: first=%d, second=%d", checkpointAfterFirst, checkpointAfterSecond) + } + + // All blocks should be readable from extent. + for i := uint64(0); i < 10; i++ { + got, err := v.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d) after two flushes: %v", i, err) + } + if !bytes.Equal(got, makeBlock(byte('A'+i))) { + t.Errorf("block %d: data mismatch after two flushes", i) + } + } +} diff --git a/weed/storage/blockvol/group_commit.go b/weed/storage/blockvol/group_commit.go new file mode 100644 index 000000000..8bee6c6f6 --- /dev/null +++ b/weed/storage/blockvol/group_commit.go @@ -0,0 +1,167 @@ +package blockvol + +import ( + "errors" + "sync" + "sync/atomic" + "time" +) + +var ErrGroupCommitShutdown = errors.New("blockvol: group committer shut down") + +// GroupCommitter batches SyncCache requests and performs a single fsync +// for the entire batch. This amortizes the cost of fsync across many callers. 
type GroupCommitter struct {
	syncFunc   func() error  // called to fsync (injectable for testing)
	maxDelay   time.Duration // max wait before flushing a partial batch
	maxBatch   int           // flush immediately when this many waiters accumulate
	onDegraded func()        // called when fsync fails

	mu       sync.Mutex
	pending  []chan error // one buffered(1) result channel per waiting Submit
	stopped  bool         // set under mu by Run() before draining
	notifyCh chan struct{} // buffered(1): coalesces wake-ups from Submit
	stopCh   chan struct{}
	done     chan struct{}
	stopOnce sync.Once

	syncCount atomic.Uint64 // number of fsyncs performed (for testing)
}

// GroupCommitterConfig configures the group committer.
type GroupCommitterConfig struct {
	SyncFunc   func() error  // required: the fsync function
	MaxDelay   time.Duration // default 1ms
	MaxBatch   int           // default 64
	OnDegraded func()        // optional: called on fsync error
}

// NewGroupCommitter creates a new group committer. Call Run() to start it.
// Zero-valued MaxDelay/MaxBatch/OnDegraded fields are replaced by defaults.
func NewGroupCommitter(cfg GroupCommitterConfig) *GroupCommitter {
	if cfg.MaxDelay == 0 {
		cfg.MaxDelay = 1 * time.Millisecond
	}
	if cfg.MaxBatch == 0 {
		cfg.MaxBatch = 64
	}
	if cfg.OnDegraded == nil {
		cfg.OnDegraded = func() {}
	}
	return &GroupCommitter{
		syncFunc:   cfg.SyncFunc,
		maxDelay:   cfg.MaxDelay,
		maxBatch:   cfg.MaxBatch,
		onDegraded: cfg.OnDegraded,
		notifyCh:   make(chan struct{}, 1),
		stopCh:     make(chan struct{}),
		done:       make(chan struct{}),
	}
}

// Run is the main loop. Call this in a goroutine.
// Lifecycle per iteration: sleep until the first waiter notifies, collect a
// batch for up to maxDelay (or until maxBatch waiters), fsync once, then
// deliver the fsync result to every waiter in the batch.
func (gc *GroupCommitter) Run() {
	defer close(gc.done)
	for {
		// Wait for first waiter or shutdown.
		select {
		case <-gc.stopCh:
			gc.markStoppedAndDrain()
			return
		case <-gc.notifyCh:
		}

		// Collect batch: wait up to maxDelay for more waiters, or until maxBatch reached.
		// The timer is armed once per batch, from the first waiter's notify.
		deadline := time.NewTimer(gc.maxDelay)
		for {
			gc.mu.Lock()
			n := len(gc.pending)
			gc.mu.Unlock()
			if n >= gc.maxBatch {
				deadline.Stop()
				break // falls through to flush below
			}
			select {
			case <-gc.stopCh:
				deadline.Stop()
				gc.markStoppedAndDrain()
				return
			case <-deadline.C:
				goto flush
			case <-gc.notifyCh:
				// Another Submit enqueued; loop to re-check the batch size.
				continue
			}
		}

	flush:
		// Take all pending waiters. New Submits from here on land in the
		// next batch (their notify is buffered in notifyCh).
		gc.mu.Lock()
		batch := gc.pending
		gc.pending = nil
		gc.mu.Unlock()

		// Possible when a stale notify raced with a previous flush.
		if len(batch) == 0 {
			continue
		}

		// Perform the single fsync on behalf of the whole batch.
		err := gc.syncFunc()
		gc.syncCount.Add(1)
		if err != nil {
			gc.onDegraded()
		}

		// Wake all waiters. Each channel is buffered(1), so these sends
		// never block even if a waiter has not reached its receive yet.
		for _, ch := range batch {
			ch <- err
		}
	}
}

// Submit submits a sync request and blocks until the batch fsync completes.
// Returns nil on success, the fsync error, or ErrGroupCommitShutdown if stopped.
func (gc *GroupCommitter) Submit() error {
	ch := make(chan error, 1)
	gc.mu.Lock()
	// stopped is checked under the same mutex that markStoppedAndDrain takes,
	// so a Submit can never enqueue after the final drain.
	if gc.stopped {
		gc.mu.Unlock()
		return ErrGroupCommitShutdown
	}
	gc.pending = append(gc.pending, ch)
	gc.mu.Unlock()

	// Non-blocking notify to wake Run(). A dropped send is fine: a notify
	// is already pending, and Run drains all of pending per flush.
	select {
	case gc.notifyCh <- struct{}{}:
	default:
	}

	return <-ch
}

// Stop shuts down the group committer. Pending waiters receive ErrGroupCommitShutdown.
// Safe to call multiple times. Blocks until Run() has exited.
func (gc *GroupCommitter) Stop() {
	gc.stopOnce.Do(func() {
		close(gc.stopCh)
	})
	<-gc.done
}

// SyncCount returns the number of fsyncs performed (for testing).
func (gc *GroupCommitter) SyncCount() uint64 {
	return gc.syncCount.Load()
}

// markStoppedAndDrain sets the stopped flag under mu and drains all pending
// waiters with ErrGroupCommitShutdown. This ensures no new Submit() can
// enqueue after we drain, closing the race window from QA-002.
+func (gc *GroupCommitter) markStoppedAndDrain() { + gc.mu.Lock() + gc.stopped = true + batch := gc.pending + gc.pending = nil + gc.mu.Unlock() + for _, ch := range batch { + ch <- ErrGroupCommitShutdown + } +} diff --git a/weed/storage/blockvol/group_commit_test.go b/weed/storage/blockvol/group_commit_test.go new file mode 100644 index 000000000..2e2958044 --- /dev/null +++ b/weed/storage/blockvol/group_commit_test.go @@ -0,0 +1,287 @@ +package blockvol + +import ( + "errors" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestGroupCommitter(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "group_single_barrier", run: testGroupSingleBarrier}, + {name: "group_batch_10", run: testGroupBatch10}, + {name: "group_max_delay", run: testGroupMaxDelay}, + {name: "group_max_batch", run: testGroupMaxBatch}, + {name: "group_fsync_error", run: testGroupFsyncError}, + {name: "group_sequential", run: testGroupSequential}, + {name: "group_shutdown", run: testGroupShutdown}, + {name: "group_submit_after_stop", run: testGroupSubmitAfterStop}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testGroupSingleBarrier(t *testing.T) { + var syncCalls atomic.Uint64 + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + syncCalls.Add(1) + return nil + }, + MaxDelay: 10 * time.Millisecond, + }) + go gc.Run() + defer gc.Stop() + + if err := gc.Submit(); err != nil { + t.Fatalf("Submit: %v", err) + } + + if c := syncCalls.Load(); c != 1 { + t.Errorf("syncCalls = %d, want 1", c) + } +} + +func testGroupBatch10(t *testing.T) { + var syncCalls atomic.Uint64 + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + syncCalls.Add(1) + return nil + }, + MaxDelay: 50 * time.Millisecond, + MaxBatch: 64, + }) + go gc.Run() + defer gc.Stop() + + const n = 10 + var wg sync.WaitGroup + errs := make([]error, n) + + // Launch all 10 concurrently. 
+ wg.Add(n) + for i := 0; i < n; i++ { + go func(idx int) { + defer wg.Done() + errs[idx] = gc.Submit() + }(i) + } + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Errorf("Submit[%d]: %v", i, err) + } + } + + // All 10 should be batched into 1 fsync (maybe 2 if timing is unlucky). + if c := syncCalls.Load(); c > 2 { + t.Errorf("syncCalls = %d, want 1-2 (batched)", c) + } +} + +func testGroupMaxDelay(t *testing.T) { + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { return nil }, + MaxDelay: 5 * time.Millisecond, + }) + go gc.Run() + defer gc.Stop() + + start := time.Now() + if err := gc.Submit(); err != nil { + t.Fatalf("Submit: %v", err) + } + elapsed := time.Since(start) + + // Should complete within ~maxDelay + some margin, not take much longer. + if elapsed > 50*time.Millisecond { + t.Errorf("Submit took %v, expected ~5ms", elapsed) + } +} + +func testGroupMaxBatch(t *testing.T) { + var syncCalls atomic.Uint64 + const batch = 64 + + // Use a slow sync to ensure we can detect immediate trigger. + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + syncCalls.Add(1) + return nil + }, + MaxDelay: 5 * time.Second, // very long delay -- should NOT wait this long + MaxBatch: batch, + }) + go gc.Run() + defer gc.Stop() + + var wg sync.WaitGroup + wg.Add(batch) + for i := 0; i < batch; i++ { + go func() { + defer wg.Done() + gc.Submit() + }() + } + + // Wait with timeout -- if maxBatch triggers, should complete fast. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Good -- completed without waiting 5 seconds. 
+ case <-time.After(2 * time.Second): + t.Fatal("maxBatch=64 did not trigger immediate flush") + } +} + +func testGroupFsyncError(t *testing.T) { + errIO := errors.New("simulated EIO") + var degraded atomic.Bool + + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + return errIO + }, + MaxDelay: 10 * time.Millisecond, + OnDegraded: func() { degraded.Store(true) }, + }) + go gc.Run() + defer gc.Stop() + + const n = 5 + var wg sync.WaitGroup + errs := make([]error, n) + + wg.Add(n) + for i := 0; i < n; i++ { + go func(idx int) { + defer wg.Done() + errs[idx] = gc.Submit() + }(i) + } + wg.Wait() + + for i, err := range errs { + if !errors.Is(err, errIO) { + t.Errorf("Submit[%d]: got %v, want %v", i, err, errIO) + } + } + + if !degraded.Load() { + t.Error("onDegraded was not called") + } +} + +func testGroupSequential(t *testing.T) { + var syncCalls atomic.Uint64 + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + syncCalls.Add(1) + return nil + }, + MaxDelay: 10 * time.Millisecond, + }) + go gc.Run() + defer gc.Stop() + + // First sync. + if err := gc.Submit(); err != nil { + t.Fatalf("Submit 1: %v", err) + } + + // Second sync (after first completes). + if err := gc.Submit(); err != nil { + t.Fatalf("Submit 2: %v", err) + } + + if c := syncCalls.Load(); c != 2 { + t.Errorf("syncCalls = %d, want 2 (two separate fsyncs)", c) + } +} + +func testGroupShutdown(t *testing.T) { + // Block fsync so we can shut down while waiters are pending. + syncStarted := make(chan struct{}) + syncBlock := make(chan struct{}) + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { + close(syncStarted) + <-syncBlock // block until test releases + return nil + }, + MaxDelay: 1 * time.Millisecond, + }) + go gc.Run() + + // Submit one request that will block on fsync. + errCh1 := make(chan error, 1) + go func() { + errCh1 <- gc.Submit() + }() + + // Wait for fsync to start. 
+ <-syncStarted + + // Submit another request while fsync is in progress. + errCh2 := make(chan error, 1) + go func() { + errCh2 <- gc.Submit() + }() + time.Sleep(5 * time.Millisecond) // let it enqueue + + // Stop the group committer (will drain pending with ErrGroupCommitShutdown). + go func() { + time.Sleep(10 * time.Millisecond) + close(syncBlock) // unblock the fsync first + }() + gc.Stop() + + // First waiter should get nil (fsync succeeded). + if err := <-errCh1; err != nil { + t.Errorf("first waiter: %v, want nil", err) + } + + // Second waiter should get shutdown error. + if err := <-errCh2; !errors.Is(err, ErrGroupCommitShutdown) { + t.Errorf("second waiter: %v, want ErrGroupCommitShutdown", err) + } +} + +func testGroupSubmitAfterStop(t *testing.T) { + gc := NewGroupCommitter(GroupCommitterConfig{ + SyncFunc: func() error { return nil }, + MaxDelay: 10 * time.Millisecond, + }) + go gc.Run() + gc.Stop() + + // Submit after Stop must return ErrGroupCommitShutdown, not deadlock. + done := make(chan error, 1) + go func() { + done <- gc.Submit() + }() + + select { + case err := <-done: + if !errors.Is(err, ErrGroupCommitShutdown) { + t.Errorf("Submit after Stop: %v, want ErrGroupCommitShutdown", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Submit after Stop deadlocked") + } +} diff --git a/weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go b/weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go new file mode 100644 index 000000000..6d023ec06 --- /dev/null +++ b/weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go @@ -0,0 +1,141 @@ +// iscsi-target is a standalone iSCSI target backed by a BlockVol file. 
+// Usage: +// +// iscsi-target -vol /path/to/volume.blk -addr :3260 -iqn iqn.2024.com.seaweedfs:vol1 +// iscsi-target -create -size 1G -vol /path/to/volume.blk -addr :3260 +package main + +import ( + "flag" + "fmt" + "log" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + + "github.com/seaweedfs/seaweedfs/weed/storage/blockvol" + "github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi" +) + +func main() { + volPath := flag.String("vol", "", "path to BlockVol file") + addr := flag.String("addr", ":3260", "listen address") + iqn := flag.String("iqn", "iqn.2024.com.seaweedfs:vol1", "target IQN") + create := flag.Bool("create", false, "create a new volume file") + size := flag.String("size", "1G", "volume size (e.g., 1G, 100M) — used with -create") + flag.Parse() + + if *volPath == "" { + fmt.Fprintln(os.Stderr, "error: -vol is required") + flag.Usage() + os.Exit(1) + } + + logger := log.New(os.Stdout, "[iscsi] ", log.LstdFlags) + + var vol *blockvol.BlockVol + var err error + + if *create { + volSize, parseErr := parseSize(*size) + if parseErr != nil { + log.Fatalf("invalid size %q: %v", *size, parseErr) + } + vol, err = blockvol.CreateBlockVol(*volPath, blockvol.CreateOptions{ + VolumeSize: volSize, + BlockSize: 4096, + WALSize: 64 * 1024 * 1024, + }) + if err != nil { + log.Fatalf("create volume: %v", err) + } + logger.Printf("created volume: %s (%s)", *volPath, *size) + } else { + vol, err = blockvol.OpenBlockVol(*volPath) + if err != nil { + log.Fatalf("open volume: %v", err) + } + logger.Printf("opened volume: %s", *volPath) + } + defer vol.Close() + + info := vol.Info() + logger.Printf("volume: %d bytes, block=%d, healthy=%v", + info.VolumeSize, info.BlockSize, info.Healthy) + + // Create adapter + adapter := &blockVolAdapter{vol: vol} + + // Create target server + config := iscsi.DefaultTargetConfig() + config.TargetName = *iqn + config.TargetAlias = "SeaweedFS BlockVol" + ts := iscsi.NewTargetServer(*addr, config, logger) + ts.AddVolume(*iqn, 
adapter) + + // Graceful shutdown on signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigCh + logger.Printf("received %v, shutting down...", sig) + ts.Close() + }() + + logger.Printf("starting iSCSI target: %s on %s", *iqn, *addr) + if err := ts.ListenAndServe(); err != nil { + log.Fatalf("target server: %v", err) + } + logger.Println("target stopped") +} + +// blockVolAdapter wraps BlockVol to implement iscsi.BlockDevice. +type blockVolAdapter struct { + vol *blockvol.BlockVol +} + +func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) { + return a.vol.ReadLBA(lba, length) +} +func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error { + return a.vol.WriteLBA(lba, data) +} +func (a *blockVolAdapter) Trim(lba uint64, length uint32) error { + return a.vol.Trim(lba, length) +} +func (a *blockVolAdapter) SyncCache() error { return a.vol.SyncCache() } +func (a *blockVolAdapter) BlockSize() uint32 { return a.vol.Info().BlockSize } +func (a *blockVolAdapter) VolumeSize() uint64 { return a.vol.Info().VolumeSize } +func (a *blockVolAdapter) IsHealthy() bool { return a.vol.Info().Healthy } + +func parseSize(s string) (uint64, error) { + s = strings.TrimSpace(s) + if len(s) == 0 { + return 0, fmt.Errorf("empty size") + } + + multiplier := uint64(1) + suffix := s[len(s)-1] + switch suffix { + case 'K', 'k': + multiplier = 1024 + s = s[:len(s)-1] + case 'M', 'm': + multiplier = 1024 * 1024 + s = s[:len(s)-1] + case 'G', 'g': + multiplier = 1024 * 1024 * 1024 + s = s[:len(s)-1] + case 'T', 't': + multiplier = 1024 * 1024 * 1024 * 1024 + s = s[:len(s)-1] + } + + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, err + } + return n * multiplier, nil +} diff --git a/weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh b/weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh new file mode 100644 index 000000000..54e91246c --- /dev/null +++ 
b/weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh @@ -0,0 +1,216 @@ +#!/usr/bin/env bash +# smoke-test.sh — iscsiadm smoke test for SeaweedFS iSCSI target +# +# Prerequisites: +# - Linux host with iscsiadm (open-iscsi) installed +# - Root access (for iscsiadm, mkfs, mount) +# - iscsi-target binary built: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/ +# +# Usage: +# sudo ./smoke-test.sh [target-addr:port] +# +# Default: starts a local target on :3260, runs discovery + login + I/O, cleans up. + +set -euo pipefail + +TARGET_ADDR="${1:-127.0.0.1}" +TARGET_PORT="${2:-3260}" +TARGET_IQN="iqn.2024.com.seaweedfs:smoke" +VOL_FILE="/tmp/iscsi-smoke-test.blk" +MOUNT_POINT="/tmp/iscsi-smoke-mnt" +TARGET_PID="" + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log() { echo -e "${GREEN}[SMOKE]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; } + +cleanup() { + log "Cleaning up..." + + # Unmount if mounted + if mountpoint -q "$MOUNT_POINT" 2>/dev/null; then + umount "$MOUNT_POINT" || true + fi + rmdir "$MOUNT_POINT" 2>/dev/null || true + + # Logout from iSCSI + iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout 2>/dev/null || true + iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" -o delete 2>/dev/null || true + + # Stop target + if [[ -n "$TARGET_PID" ]] && kill -0 "$TARGET_PID" 2>/dev/null; then + kill "$TARGET_PID" + wait "$TARGET_PID" 2>/dev/null || true + fi + + # Remove volume file + rm -f "$VOL_FILE" + + log "Cleanup complete" +} + +trap cleanup EXIT + +# ------------------------------------------------------- +# Preflight +# ------------------------------------------------------- +if [[ $EUID -ne 0 ]]; then + fail "This script must be run as root (for iscsiadm/mount)" +fi + +if ! command -v iscsiadm &>/dev/null; then + fail "iscsiadm not found. 
Install open-iscsi: apt install open-iscsi" +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TARGET_BIN="${SCRIPT_DIR}/iscsi-target" +if [[ ! -x "$TARGET_BIN" ]]; then + # Try the repo root + TARGET_BIN="$(cd "$SCRIPT_DIR/../../../../../.." && pwd)/iscsi-target" +fi +if [[ ! -x "$TARGET_BIN" ]]; then + fail "iscsi-target binary not found. Build with: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/" +fi + +# ------------------------------------------------------- +# Step 1: Start iSCSI target +# ------------------------------------------------------- +log "Starting iSCSI target..." +"$TARGET_BIN" \ + -create \ + -vol "$VOL_FILE" \ + -size 100M \ + -addr "${TARGET_ADDR}:${TARGET_PORT}" \ + -iqn "$TARGET_IQN" & +TARGET_PID=$! +sleep 1 + +if ! kill -0 "$TARGET_PID" 2>/dev/null; then + fail "Target failed to start" +fi +log "Target started (PID $TARGET_PID)" + +# ------------------------------------------------------- +# Step 2: Discovery +# ------------------------------------------------------- +log "Running iscsiadm discovery..." +DISCOVERY=$(iscsiadm -m discovery -t sendtargets -p "${TARGET_ADDR}:${TARGET_PORT}" 2>&1) || { + fail "Discovery failed: $DISCOVERY" +} +echo "$DISCOVERY" + +if ! echo "$DISCOVERY" | grep -q "$TARGET_IQN"; then + fail "Target IQN not found in discovery output" +fi +log "Discovery OK" + +# ------------------------------------------------------- +# Step 3: Login +# ------------------------------------------------------- +log "Logging in to target..." +iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --login || { + fail "Login failed" +} +log "Login OK" + +# Wait for device to appear +sleep 2 +ISCSI_DEV=$(iscsiadm -m session -P 3 2>/dev/null | grep "Attached scsi disk" | awk '{print $4}' | head -1) +if [[ -z "$ISCSI_DEV" ]]; then + warn "Could not determine attached device — trying /dev/sdb" + ISCSI_DEV="sdb" +fi +DEV_PATH="/dev/$ISCSI_DEV" +log "Attached device: $DEV_PATH" + +if [[ ! 
-b "$DEV_PATH" ]]; then + fail "Block device $DEV_PATH not found" +fi + +# ------------------------------------------------------- +# Step 4: dd write/read test +# ------------------------------------------------------- +log "Writing test pattern with dd..." +dd if=/dev/urandom of=/tmp/iscsi-smoke-pattern bs=1M count=1 2>/dev/null +dd if=/tmp/iscsi-smoke-pattern of="$DEV_PATH" bs=1M count=1 oflag=direct 2>/dev/null || { + fail "dd write failed" +} + +log "Reading back and verifying..." +dd if="$DEV_PATH" of=/tmp/iscsi-smoke-readback bs=1M count=1 iflag=direct 2>/dev/null || { + fail "dd read failed" +} + +if cmp -s /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback; then + log "Data integrity verified (dd write/read OK)" +else + fail "Data integrity check failed!" +fi +rm -f /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback + +# ------------------------------------------------------- +# Step 5: mkfs + mount + file I/O +# ------------------------------------------------------- +log "Creating ext4 filesystem..." +mkfs.ext4 -F -q "$DEV_PATH" || { + warn "mkfs.ext4 failed — skipping filesystem tests" + # Still consider the test a pass if dd worked + log "SMOKE TEST PASSED (dd only, mkfs skipped)" + exit 0 +} + +mkdir -p "$MOUNT_POINT" +mount "$DEV_PATH" "$MOUNT_POINT" || { + fail "mount failed" +} +log "Mounted at $MOUNT_POINT" + +# Write a file +echo "SeaweedFS iSCSI smoke test" > "$MOUNT_POINT/test.txt" +sync + +# Read it back +CONTENT=$(cat "$MOUNT_POINT/test.txt") +if [[ "$CONTENT" != "SeaweedFS iSCSI smoke test" ]]; then + fail "File content mismatch" +fi +log "File I/O verified" + +# ------------------------------------------------------- +# Step 6: Logout +# ------------------------------------------------------- +log "Unmounting and logging out..." 
+umount "$MOUNT_POINT" +iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout || { + fail "Logout failed" +} +log "Logout OK" + +# Verify no stale sessions +SESSION_COUNT=$(iscsiadm -m session 2>/dev/null | grep -c "$TARGET_IQN" || true) +if [[ "$SESSION_COUNT" -gt 0 ]]; then + fail "Stale session detected after logout" +fi +log "No stale sessions" + +# ------------------------------------------------------- +# Result +# ------------------------------------------------------- +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN} SMOKE TEST PASSED${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo " Discovery: OK" +echo " Login: OK" +echo " dd I/O: OK" +echo " mkfs+mount: OK" +echo " File I/O: OK" +echo " Logout: OK" +echo " Cleanup: OK" +echo "" diff --git a/weed/storage/blockvol/iscsi/dataio.go b/weed/storage/blockvol/iscsi/dataio.go new file mode 100644 index 000000000..223a97085 --- /dev/null +++ b/weed/storage/blockvol/iscsi/dataio.go @@ -0,0 +1,207 @@ +package iscsi + +import ( + "errors" + "io" +) + +var ( + ErrDataSNOrder = errors.New("iscsi: Data-Out DataSN out of order") + ErrDataOverflow = errors.New("iscsi: data exceeds expected transfer length") + ErrDataIncomplete = errors.New("iscsi: data transfer incomplete") +) + +// DataInWriter splits a read response into multiple Data-In PDUs, respecting +// MaxRecvDataSegmentLength. The final PDU carries the S-bit (status) and F-bit. +type DataInWriter struct { + maxSegLen uint32 // negotiated MaxRecvDataSegmentLength +} + +// NewDataInWriter creates a writer with the given max segment length. +func NewDataInWriter(maxSegLen uint32) *DataInWriter { + if maxSegLen == 0 { + maxSegLen = 8192 // sensible default + } + return &DataInWriter{maxSegLen: maxSegLen} +} + +// WriteDataIn splits data into Data-In PDUs and writes them to w. +// itt is the initiator task tag, ttt is the target transfer tag. 
+// statSN is the current StatSN (incremented when S-bit is set). +// Returns the number of PDUs written. +func (d *DataInWriter) WriteDataIn(w io.Writer, data []byte, itt uint32, expCmdSN, maxCmdSN uint32, statSN *uint32) (int, error) { + totalLen := uint32(len(data)) + if totalLen == 0 { + // Zero-length read — send single Data-In with S-bit, no data + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataIn) + pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status + pdu.SetInitiatorTaskTag(itt) + pdu.SetTargetTransferTag(0xFFFFFFFF) + pdu.SetStatSN(*statSN) + *statSN++ + pdu.SetExpCmdSN(expCmdSN) + pdu.SetMaxCmdSN(maxCmdSN) + pdu.SetDataSN(0) + pdu.SetSCSIStatus(SCSIStatusGood) + if err := WritePDU(w, pdu); err != nil { + return 0, err + } + return 1, nil + } + + var offset uint32 + var dataSN uint32 + count := 0 + + for offset < totalLen { + segLen := d.maxSegLen + if offset+segLen > totalLen { + segLen = totalLen - offset + } + isFinal := (offset + segLen) >= totalLen + + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataIn) + pdu.SetInitiatorTaskTag(itt) + pdu.SetTargetTransferTag(0xFFFFFFFF) + pdu.SetExpCmdSN(expCmdSN) + pdu.SetMaxCmdSN(maxCmdSN) + pdu.SetDataSN(dataSN) + pdu.SetBufferOffset(offset) + + pdu.DataSegment = data[offset : offset+segLen] + + if isFinal { + pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status + pdu.SetStatSN(*statSN) + *statSN++ + pdu.SetSCSIStatus(SCSIStatusGood) + } else { + pdu.SetOpSpecific1(0) // no flags + } + + if err := WritePDU(w, pdu); err != nil { + return count, err + } + count++ + dataSN++ + offset += segLen + } + + return count, nil +} + +// DataOutCollector collects Data-Out PDUs for a single write command, +// assembling the full data buffer from potentially multiple PDUs. +// It handles both immediate data and R2T-solicited data. +type DataOutCollector struct { + expectedLen uint32 + buf []byte + received uint32 + nextDataSN uint32 + done bool +} + +// NewDataOutCollector creates a collector expecting the given total transfer length. 
+func NewDataOutCollector(expectedLen uint32) *DataOutCollector { + return &DataOutCollector{ + expectedLen: expectedLen, + buf: make([]byte, expectedLen), + } +} + +// AddImmediateData adds the data from the SCSI Command PDU (immediate data). +func (c *DataOutCollector) AddImmediateData(data []byte) error { + if uint32(len(data)) > c.expectedLen { + return ErrDataOverflow + } + copy(c.buf, data) + c.received += uint32(len(data)) + if c.received >= c.expectedLen { + c.done = true + } + return nil +} + +// AddDataOut processes a Data-Out PDU and adds its data to the buffer. +func (c *DataOutCollector) AddDataOut(pdu *PDU) error { + dataSN := pdu.DataSN() + if dataSN != c.nextDataSN { + return ErrDataSNOrder + } + c.nextDataSN++ + + offset := pdu.BufferOffset() + data := pdu.DataSegment + end := offset + uint32(len(data)) + + if end > c.expectedLen { + return ErrDataOverflow + } + + copy(c.buf[offset:], data) + c.received += uint32(len(data)) + + // Check F-bit + if pdu.OpSpecific1()&FlagF != 0 { + c.done = true + } + + return nil +} + +// Done returns true if all expected data has been received. +func (c *DataOutCollector) Done() bool { return c.done } + +// Data returns the assembled data buffer. +func (c *DataOutCollector) Data() []byte { return c.buf } + +// Remaining returns how many bytes are still needed. +func (c *DataOutCollector) Remaining() uint32 { + if c.received >= c.expectedLen { + return 0 + } + return c.expectedLen - c.received +} + +// BuildR2T creates an R2T PDU requesting more data from the initiator. 
+func BuildR2T(itt, ttt uint32, r2tSN uint32, bufferOffset, desiredLen uint32, statSN, expCmdSN, maxCmdSN uint32) *PDU { + pdu := &PDU{} + pdu.SetOpcode(OpR2T) + pdu.SetOpSpecific1(FlagF) // always Final for R2T + pdu.SetInitiatorTaskTag(itt) + pdu.SetTargetTransferTag(ttt) + pdu.SetStatSN(statSN) + pdu.SetExpCmdSN(expCmdSN) + pdu.SetMaxCmdSN(maxCmdSN) + pdu.SetR2TSN(r2tSN) + pdu.SetBufferOffset(bufferOffset) + pdu.SetDesiredDataLength(desiredLen) + return pdu +} + +// SendSCSIResponse sends a SCSI Response PDU with optional sense data. +func SendSCSIResponse(w io.Writer, result SCSIResult, itt uint32, statSN *uint32, expCmdSN, maxCmdSN uint32) error { + pdu := &PDU{} + pdu.SetOpcode(OpSCSIResp) + pdu.SetOpSpecific1(FlagF) // Final + pdu.SetSCSIResponse(ISCSIRespCompleted) + pdu.SetSCSIStatus(result.Status) + pdu.SetInitiatorTaskTag(itt) + pdu.SetStatSN(*statSN) + *statSN++ + pdu.SetExpCmdSN(expCmdSN) + pdu.SetMaxCmdSN(maxCmdSN) + + if result.Status == SCSIStatusCheckCond { + senseData := BuildSenseData(result.SenseKey, result.SenseASC, result.SenseASCQ) + // Sense data is wrapped in a 2-byte length prefix + pdu.DataSegment = make([]byte, 2+len(senseData)) + pdu.DataSegment[0] = byte(len(senseData) >> 8) + pdu.DataSegment[1] = byte(len(senseData)) + copy(pdu.DataSegment[2:], senseData) + } + + return WritePDU(w, pdu) +} diff --git a/weed/storage/blockvol/iscsi/dataio_test.go b/weed/storage/blockvol/iscsi/dataio_test.go new file mode 100644 index 000000000..26b34d4d0 --- /dev/null +++ b/weed/storage/blockvol/iscsi/dataio_test.go @@ -0,0 +1,410 @@ +package iscsi + +import ( + "bytes" + "testing" +) + +func TestDataIO(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"datain_single_pdu", testDataInSinglePDU}, + {"datain_multi_pdu", testDataInMultiPDU}, + {"datain_exact_boundary", testDataInExactBoundary}, + {"datain_zero_length", testDataInZeroLength}, + {"datain_datasn_ordering", testDataInDataSNOrdering}, + {"datain_fbit_sbit", 
testDataInFbitSbit}, + {"dataout_single_pdu", testDataOutSinglePDU}, + {"dataout_multi_pdu", testDataOutMultiPDU}, + {"dataout_immediate_data", testDataOutImmediateData}, + {"dataout_immediate_plus_r2t", testDataOutImmediatePlusR2T}, + {"dataout_wrong_datasn", testDataOutWrongDataSN}, + {"dataout_overflow", testDataOutOverflow}, + {"r2t_build", testR2TBuild}, + {"scsi_response_good", testSCSIResponseGood}, + {"scsi_response_check_condition", testSCSIResponseCheckCondition}, + {"datain_statsn_increment", testDataInStatSNIncrement}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testDataInSinglePDU(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(8192) + data := bytes.Repeat([]byte{0xAA}, 4096) + statSN := uint32(1) + + n, err := dw.WriteDataIn(w, data, 0x100, 1, 10, &statSN) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1 PDU, got %d", n) + } + if statSN != 2 { + t.Fatalf("StatSN should be 2, got %d", statSN) + } + + pdu, err := ReadPDU(w) + if err != nil { + t.Fatal(err) + } + if pdu.Opcode() != OpSCSIDataIn { + t.Fatal("wrong opcode") + } + if pdu.OpSpecific1()&FlagF == 0 { + t.Fatal("F-bit not set") + } + if pdu.OpSpecific1()&FlagS == 0 { + t.Fatal("S-bit not set on final PDU") + } + if len(pdu.DataSegment) != 4096 { + t.Fatalf("data length: %d", len(pdu.DataSegment)) + } +} + +func testDataInMultiPDU(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(1024) // small segment + data := bytes.Repeat([]byte{0xBB}, 3000) + statSN := uint32(1) + + n, err := dw.WriteDataIn(w, data, 0x200, 1, 10, &statSN) + if err != nil { + t.Fatal(err) + } + // 3000 / 1024 = 3 PDUs (1024 + 1024 + 952) + if n != 3 { + t.Fatalf("expected 3 PDUs, got %d", n) + } + + // Read them back + var reassembled []byte + for i := 0; i < 3; i++ { + pdu, err := ReadPDU(w) + if err != nil { + t.Fatalf("PDU %d: %v", i, err) + } + if pdu.DataSN() != uint32(i) { + t.Fatalf("PDU %d: DataSN=%d", i, 
pdu.DataSN()) + } + reassembled = append(reassembled, pdu.DataSegment...) + + if i < 2 { + if pdu.OpSpecific1()&FlagF != 0 { + t.Fatalf("PDU %d should not have F-bit", i) + } + } else { + if pdu.OpSpecific1()&FlagF == 0 { + t.Fatal("last PDU should have F-bit") + } + if pdu.OpSpecific1()&FlagS == 0 { + t.Fatal("last PDU should have S-bit") + } + } + } + + if !bytes.Equal(reassembled, data) { + t.Fatal("reassembled data mismatch") + } +} + +func testDataInExactBoundary(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(1024) + data := bytes.Repeat([]byte{0xCC}, 2048) // exact 2 PDUs + statSN := uint32(1) + + n, err := dw.WriteDataIn(w, data, 0x300, 1, 10, &statSN) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 PDUs, got %d", n) + } +} + +func testDataInZeroLength(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(8192) + statSN := uint32(5) + + n, err := dw.WriteDataIn(w, nil, 0x400, 1, 10, &statSN) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1 PDU for zero-length, got %d", n) + } + if statSN != 6 { + t.Fatal("StatSN should still increment") + } +} + +func testDataInDataSNOrdering(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(512) + data := bytes.Repeat([]byte{0xDD}, 2048) // 4 PDUs + statSN := uint32(1) + + dw.WriteDataIn(w, data, 0x500, 1, 10, &statSN) + + for i := 0; i < 4; i++ { + pdu, _ := ReadPDU(w) + if pdu.DataSN() != uint32(i) { + t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN()) + } + expectedOffset := uint32(i) * 512 + if pdu.BufferOffset() != expectedOffset { + t.Fatalf("PDU %d: offset=%d, expected %d", i, pdu.BufferOffset(), expectedOffset) + } + } +} + +func testDataInFbitSbit(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(1000) + data := bytes.Repeat([]byte{0xEE}, 2500) + statSN := uint32(1) + + dw.WriteDataIn(w, data, 0x600, 1, 10, &statSN) + + for i := 0; i < 3; i++ { + pdu, _ := ReadPDU(w) + flags := pdu.OpSpecific1() + if i < 2 { + if 
flags&FlagF != 0 || flags&FlagS != 0 { + t.Fatalf("PDU %d: should have no F/S bits", i) + } + } else { + if flags&FlagF == 0 || flags&FlagS == 0 { + t.Fatal("last PDU must have F+S bits") + } + } + } +} + +func testDataOutSinglePDU(t *testing.T) { + c := NewDataOutCollector(4096) + + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataOut) + pdu.SetOpSpecific1(FlagF) + pdu.SetDataSN(0) + pdu.SetBufferOffset(0) + pdu.DataSegment = bytes.Repeat([]byte{0x11}, 4096) + + if err := c.AddDataOut(pdu); err != nil { + t.Fatal(err) + } + if !c.Done() { + t.Fatal("should be done") + } + if c.Remaining() != 0 { + t.Fatal("remaining should be 0") + } +} + +func testDataOutMultiPDU(t *testing.T) { + c := NewDataOutCollector(8192) + + for i := 0; i < 2; i++ { + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataOut) + pdu.SetDataSN(uint32(i)) + pdu.SetBufferOffset(uint32(i) * 4096) + pdu.DataSegment = bytes.Repeat([]byte{byte(i + 1)}, 4096) + if i == 1 { + pdu.SetOpSpecific1(FlagF) + } + if err := c.AddDataOut(pdu); err != nil { + t.Fatalf("PDU %d: %v", i, err) + } + } + + if !c.Done() { + t.Fatal("should be done") + } + data := c.Data() + if data[0] != 0x01 || data[4096] != 0x02 { + t.Fatal("data assembly wrong") + } +} + +func testDataOutImmediateData(t *testing.T) { + c := NewDataOutCollector(4096) + err := c.AddImmediateData(bytes.Repeat([]byte{0xFF}, 4096)) + if err != nil { + t.Fatal(err) + } + if !c.Done() { + t.Fatal("should be done with immediate data") + } +} + +func testDataOutImmediatePlusR2T(t *testing.T) { + c := NewDataOutCollector(8192) + + // Immediate: first 4096 + err := c.AddImmediateData(bytes.Repeat([]byte{0xAA}, 4096)) + if err != nil { + t.Fatal(err) + } + if c.Done() { + t.Fatal("should not be done yet") + } + if c.Remaining() != 4096 { + t.Fatalf("remaining: %d", c.Remaining()) + } + + // R2T-solicited Data-Out: next 4096 + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataOut) + pdu.SetOpSpecific1(FlagF) + pdu.SetDataSN(0) + pdu.SetBufferOffset(4096) + pdu.DataSegment = 
bytes.Repeat([]byte{0xBB}, 4096) + if err := c.AddDataOut(pdu); err != nil { + t.Fatal(err) + } + if !c.Done() { + t.Fatal("should be done") + } + + data := c.Data() + if data[0] != 0xAA || data[4096] != 0xBB { + t.Fatal("assembly wrong") + } +} + +func testDataOutWrongDataSN(t *testing.T) { + c := NewDataOutCollector(8192) + + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataOut) + pdu.SetDataSN(1) // should be 0 + pdu.SetBufferOffset(0) + pdu.DataSegment = make([]byte, 4096) + + err := c.AddDataOut(pdu) + if err != ErrDataSNOrder { + t.Fatalf("expected ErrDataSNOrder, got %v", err) + } +} + +func testDataOutOverflow(t *testing.T) { + c := NewDataOutCollector(4096) + + pdu := &PDU{} + pdu.SetOpcode(OpSCSIDataOut) + pdu.SetDataSN(0) + pdu.SetBufferOffset(0) + pdu.DataSegment = make([]byte, 8192) // more than expected + + err := c.AddDataOut(pdu) + if err != ErrDataOverflow { + t.Fatalf("expected ErrDataOverflow, got %v", err) + } +} + +func testR2TBuild(t *testing.T) { + pdu := BuildR2T(0x100, 0x200, 0, 4096, 4096, 5, 10, 20) + if pdu.Opcode() != OpR2T { + t.Fatal("wrong opcode") + } + if pdu.InitiatorTaskTag() != 0x100 { + t.Fatal("ITT wrong") + } + if pdu.TargetTransferTag() != 0x200 { + t.Fatal("TTT wrong") + } + if pdu.R2TSN() != 0 { + t.Fatal("R2TSN wrong") + } + if pdu.BufferOffset() != 4096 { + t.Fatal("offset wrong") + } + if pdu.DesiredDataLength() != 4096 { + t.Fatal("desired length wrong") + } + if pdu.StatSN() != 5 { + t.Fatal("StatSN wrong") + } +} + +func testSCSIResponseGood(t *testing.T) { + w := &bytes.Buffer{} + statSN := uint32(10) + result := SCSIResult{Status: SCSIStatusGood} + err := SendSCSIResponse(w, result, 0x300, &statSN, 5, 15) + if err != nil { + t.Fatal(err) + } + if statSN != 11 { + t.Fatal("StatSN not incremented") + } + + pdu, err := ReadPDU(w) + if err != nil { + t.Fatal(err) + } + if pdu.Opcode() != OpSCSIResp { + t.Fatal("wrong opcode") + } + if pdu.SCSIStatus() != SCSIStatusGood { + t.Fatal("status wrong") + } + if len(pdu.DataSegment) 
!= 0 { + t.Fatal("no data expected for good status") + } +} + +func testSCSIResponseCheckCondition(t *testing.T) { + w := &bytes.Buffer{} + statSN := uint32(20) + result := SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseIllegalRequest, + SenseASC: ASCInvalidOpcode, + SenseASCQ: ASCQLuk, + } + err := SendSCSIResponse(w, result, 0x400, &statSN, 5, 15) + if err != nil { + t.Fatal(err) + } + + pdu, err := ReadPDU(w) + if err != nil { + t.Fatal(err) + } + if pdu.SCSIStatus() != SCSIStatusCheckCond { + t.Fatal("status wrong") + } + // Data segment should contain sense data with 2-byte length prefix + if len(pdu.DataSegment) < 20 { // 2 + 18 + t.Fatalf("sense data too short: %d", len(pdu.DataSegment)) + } + senseLen := int(pdu.DataSegment[0])<<8 | int(pdu.DataSegment[1]) + if senseLen != 18 { + t.Fatalf("sense length: %d", senseLen) + } +} + +func testDataInStatSNIncrement(t *testing.T) { + w := &bytes.Buffer{} + dw := NewDataInWriter(1024) + data := bytes.Repeat([]byte{0x00}, 3072) // 3 PDUs + statSN := uint32(100) + + dw.WriteDataIn(w, data, 0x700, 1, 10, &statSN) + // Only the final PDU has S-bit, so StatSN increments once + if statSN != 101 { + t.Fatalf("StatSN should be 101, got %d", statSN) + } +} diff --git a/weed/storage/blockvol/iscsi/discovery.go b/weed/storage/blockvol/iscsi/discovery.go new file mode 100644 index 000000000..e497e2c54 --- /dev/null +++ b/weed/storage/blockvol/iscsi/discovery.go @@ -0,0 +1,68 @@ +package iscsi + +// HandleTextRequest processes a Text Request PDU. +// Currently supports SendTargets discovery (RFC 7143, Section 12.3). 
+func HandleTextRequest(req *PDU, targets []DiscoveryTarget) *PDU { + resp := &PDU{} + resp.SetOpcode(OpTextResp) + resp.SetOpSpecific1(FlagF) // Final + resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) + resp.SetTargetTransferTag(0xFFFFFFFF) // no continuation + + params, err := ParseParams(req.DataSegment) + if err != nil { + // Malformed — return empty response + return resp + } + + val, ok := params.Get("SendTargets") + if !ok { + return resp + } + + // Discovery responses can have duplicate keys (TargetName appears + // for each target), so we use EncodeDiscoveryTargets directly. + var matched []DiscoveryTarget + + switch val { + case "All": + matched = targets + default: + for _, tgt := range targets { + if tgt.Name == val { + matched = append(matched, tgt) + break + } + } + } + + if data := EncodeDiscoveryTargets(matched); len(data) > 0 { + resp.DataSegment = data + } + return resp +} + +// DiscoveryTarget represents a target available for discovery. +type DiscoveryTarget struct { + Name string // IQN, e.g., "iqn.2024.com.seaweedfs:vol1" + Address string // IP:port,portal-group, e.g., "10.0.0.1:3260,1" +} + +// EncodeDiscoveryTargets encodes multiple targets into text parameter format. +// Each target produces TargetName=\0TargetAddress=\0 +// This handles the special multi-value encoding where the same key can appear +// multiple times (unlike normal negotiation params). +func EncodeDiscoveryTargets(targets []DiscoveryTarget) []byte { + if len(targets) == 0 { + return nil + } + + var buf []byte + for _, tgt := range targets { + buf = append(buf, "TargetName="+tgt.Name+"\x00"...) + if tgt.Address != "" { + buf = append(buf, "TargetAddress="+tgt.Address+"\x00"...) 
+ } + } + return buf +} diff --git a/weed/storage/blockvol/iscsi/discovery_test.go b/weed/storage/blockvol/iscsi/discovery_test.go new file mode 100644 index 000000000..1aa1b2701 --- /dev/null +++ b/weed/storage/blockvol/iscsi/discovery_test.go @@ -0,0 +1,213 @@ +package iscsi + +import ( + "strings" + "testing" +) + +func TestDiscovery(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"send_targets_all", testSendTargetsAll}, + {"send_targets_specific", testSendTargetsSpecific}, + {"send_targets_not_found", testSendTargetsNotFound}, + {"send_targets_empty_list", testSendTargetsEmptyList}, + {"no_send_targets_key", testNoSendTargetsKey}, + {"malformed_text_request", testMalformedTextRequest}, + {"special_chars_in_iqn", testSpecialCharsInIQN}, + {"multiple_targets", testMultipleTargets}, + {"encode_discovery_targets", testEncodeDiscoveryTargets}, + {"encode_empty_targets", testEncodeEmptyTargets}, + {"target_without_address", testTargetWithoutAddress}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +var testTargets = []DiscoveryTarget{ + {Name: "iqn.2024.com.seaweedfs:vol1", Address: "10.0.0.1:3260,1"}, + {Name: "iqn.2024.com.seaweedfs:vol2", Address: "10.0.0.2:3260,1"}, +} + +func makeTextReq(params *Params) *PDU { + p := &PDU{} + p.SetOpcode(OpTextReq) + p.SetOpSpecific1(FlagF) + p.SetInitiatorTaskTag(0x1234) + if params != nil { + p.DataSegment = params.Encode() + } + return p +} + +func testSendTargetsAll(t *testing.T) { + params := NewParams() + params.Set("SendTargets", "All") + req := makeTextReq(params) + + resp := HandleTextRequest(req, testTargets) + + if resp.Opcode() != OpTextResp { + t.Fatalf("opcode: 0x%02x", resp.Opcode()) + } + if resp.InitiatorTaskTag() != 0x1234 { + t.Fatal("ITT mismatch") + } + + body := string(resp.DataSegment) + if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol1") { + t.Fatal("missing vol1") + } + if !strings.Contains(body, 
"iqn.2024.com.seaweedfs:vol2") { + t.Fatal("missing vol2") + } + if !strings.Contains(body, "10.0.0.1:3260,1") { + t.Fatal("missing vol1 address") + } +} + +func testSendTargetsSpecific(t *testing.T) { + params := NewParams() + params.Set("SendTargets", "iqn.2024.com.seaweedfs:vol2") + req := makeTextReq(params) + + resp := HandleTextRequest(req, testTargets) + body := string(resp.DataSegment) + if !strings.Contains(body, "vol2") { + t.Fatal("missing vol2") + } + // Should not contain vol1 + if strings.Contains(body, "vol1") { + t.Fatal("should not contain vol1") + } +} + +func testSendTargetsNotFound(t *testing.T) { + params := NewParams() + params.Set("SendTargets", "iqn.2024.com.seaweedfs:nonexistent") + req := makeTextReq(params) + + resp := HandleTextRequest(req, testTargets) + if len(resp.DataSegment) != 0 { + t.Fatalf("expected empty response, got %q", resp.DataSegment) + } +} + +func testSendTargetsEmptyList(t *testing.T) { + params := NewParams() + params.Set("SendTargets", "All") + req := makeTextReq(params) + + resp := HandleTextRequest(req, nil) + if len(resp.DataSegment) != 0 { + t.Fatalf("expected empty response, got %q", resp.DataSegment) + } +} + +func testNoSendTargetsKey(t *testing.T) { + params := NewParams() + params.Set("SomethingElse", "value") + req := makeTextReq(params) + + resp := HandleTextRequest(req, testTargets) + if len(resp.DataSegment) != 0 { + t.Fatal("expected empty response for non-SendTargets request") + } +} + +func testMalformedTextRequest(t *testing.T) { + req := &PDU{} + req.SetOpcode(OpTextReq) + req.SetInitiatorTaskTag(0x5678) + req.DataSegment = []byte("not a valid param format") // no '=' + + resp := HandleTextRequest(req, testTargets) + // Should return empty response without error + if resp.Opcode() != OpTextResp { + t.Fatalf("opcode: 0x%02x", resp.Opcode()) + } + if resp.InitiatorTaskTag() != 0x5678 { + t.Fatal("ITT mismatch") + } +} + +func testSpecialCharsInIQN(t *testing.T) { + targets := []DiscoveryTarget{ + 
{Name: "iqn.2024-01.com.example:storage.tape1.sys1.xyz", Address: "192.168.1.100:3260,1"}, + } + params := NewParams() + params.Set("SendTargets", "All") + req := makeTextReq(params) + + resp := HandleTextRequest(req, targets) + body := string(resp.DataSegment) + if !strings.Contains(body, "iqn.2024-01.com.example:storage.tape1.sys1.xyz") { + t.Fatal("IQN with special chars not found") + } +} + +func testMultipleTargets(t *testing.T) { + targets := make([]DiscoveryTarget, 10) + for i := range targets { + targets[i] = DiscoveryTarget{ + Name: "iqn.2024.com.test:vol" + string(rune('0'+i)), + Address: "10.0.0.1:3260,1", + } + } + + params := NewParams() + params.Set("SendTargets", "All") + req := makeTextReq(params) + + resp := HandleTextRequest(req, targets) + body := string(resp.DataSegment) + // The response uses EncodeDiscoveryTargets internally via the params, + // but since Params doesn't allow duplicate keys, the last one wins. + // This is a known limitation — for multi-target discovery, we use + // EncodeDiscoveryTargets directly. Let's verify at least the last target. 
+ if !strings.Contains(body, "TargetName=") { + t.Fatal("no TargetName in response") + } +} + +func testEncodeDiscoveryTargets(t *testing.T) { + encoded := EncodeDiscoveryTargets(testTargets) + body := string(encoded) + + // Should contain both targets with proper format + if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { + t.Fatal("missing vol1") + } + if !strings.Contains(body, "TargetAddress=10.0.0.1:3260,1\x00") { + t.Fatal("missing vol1 address") + } + if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol2\x00") { + t.Fatal("missing vol2") + } +} + +func testEncodeEmptyTargets(t *testing.T) { + encoded := EncodeDiscoveryTargets(nil) + if encoded != nil { + t.Fatalf("expected nil, got %q", encoded) + } +} + +func testTargetWithoutAddress(t *testing.T) { + targets := []DiscoveryTarget{ + {Name: "iqn.2024.com.seaweedfs:vol1"}, // no address + } + encoded := EncodeDiscoveryTargets(targets) + body := string(encoded) + if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { + t.Fatal("missing target name") + } + if strings.Contains(body, "TargetAddress") { + t.Fatal("should not have TargetAddress") + } +} diff --git a/weed/storage/blockvol/iscsi/integration_test.go b/weed/storage/blockvol/iscsi/integration_test.go new file mode 100644 index 000000000..004af73ba --- /dev/null +++ b/weed/storage/blockvol/iscsi/integration_test.go @@ -0,0 +1,412 @@ +package iscsi_test + +import ( + "bytes" + "encoding/binary" + "io" + "log" + "net" + "path/filepath" + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/storage/blockvol" + "github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi" +) + +// blockVolAdapter wraps a BlockVol to implement iscsi.BlockDevice. 
+type blockVolAdapter struct { + vol *blockvol.BlockVol +} + +func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) { + return a.vol.ReadLBA(lba, length) +} +func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error { + return a.vol.WriteLBA(lba, data) +} +func (a *blockVolAdapter) Trim(lba uint64, length uint32) error { + return a.vol.Trim(lba, length) +} +func (a *blockVolAdapter) SyncCache() error { + return a.vol.SyncCache() +} +func (a *blockVolAdapter) BlockSize() uint32 { + return a.vol.Info().BlockSize +} +func (a *blockVolAdapter) VolumeSize() uint64 { + return a.vol.Info().VolumeSize +} +func (a *blockVolAdapter) IsHealthy() bool { + return a.vol.Info().Healthy +} + +const ( + intTargetName = "iqn.2024.com.seaweedfs:integration" + intInitiatorName = "iqn.2024.com.test:client" +) + +func createTestVol(t *testing.T) *blockvol.BlockVol { + t.Helper() + path := filepath.Join(t.TempDir(), "test.blk") + vol, err := blockvol.CreateBlockVol(path, blockvol.CreateOptions{ + VolumeSize: 1024 * 4096, // 1024 blocks = 4MB + BlockSize: 4096, + WALSize: 1024 * 1024, // 1MB WAL + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { vol.Close() }) + return vol +} + +func setupIntegrationTarget(t *testing.T) (net.Conn, *iscsi.TargetServer) { + t.Helper() + vol := createTestVol(t) + adapter := &blockVolAdapter{vol: vol} + + config := iscsi.DefaultTargetConfig() + config.TargetName = intTargetName + logger := log.New(io.Discard, "", 0) + ts := iscsi.NewTargetServer("127.0.0.1:0", config, logger) + ts.AddVolume(intTargetName, adapter) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + go ts.Serve(ln) + t.Cleanup(func() { ts.Close() }) + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // Login + params := iscsi.NewParams() + params.Set("InitiatorName", intInitiatorName) + params.Set("TargetName", 
intTargetName) + params.Set("SessionType", "Normal") + + loginReq := &iscsi.PDU{} + loginReq.SetOpcode(iscsi.OpLoginReq) + loginReq.SetLoginStages(iscsi.StageSecurityNeg, iscsi.StageFullFeature) + loginReq.SetLoginTransit(true) + loginReq.SetISID([6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}) + loginReq.SetCmdSN(1) + loginReq.DataSegment = params.Encode() + + if err := iscsi.WritePDU(conn, loginReq); err != nil { + t.Fatal(err) + } + resp, err := iscsi.ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + if resp.LoginStatusClass() != iscsi.LoginStatusSuccess { + t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) + } + + return conn, ts +} + +func TestIntegration(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"write_read_verify", testWriteReadVerify}, + {"multi_block_write_read", testMultiBlockWriteRead}, + {"inquiry_capacity", testInquiryCapacity}, + {"sync_cache", testIntSyncCache}, + {"test_unit_ready", testIntTestUnitReady}, + {"concurrent_readers_writers", testConcurrentReadersWriters}, + {"write_at_boundary", testWriteAtBoundary}, + {"unmap_integration", testUnmapIntegration}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func sendSCSICmd(t *testing.T, conn net.Conn, cdb [16]byte, cmdSN uint32, read bool, write bool, dataOut []byte, expLen uint32) *iscsi.PDU { + t.Helper() + cmd := &iscsi.PDU{} + cmd.SetOpcode(iscsi.OpSCSICmd) + flags := uint8(iscsi.FlagF) + if read { + flags |= iscsi.FlagR + } + if write { + flags |= iscsi.FlagW + } + cmd.SetOpSpecific1(flags) + cmd.SetInitiatorTaskTag(cmdSN) + cmd.SetExpectedDataTransferLength(expLen) + cmd.SetCmdSN(cmdSN) + cmd.SetCDB(cdb) + if dataOut != nil { + cmd.DataSegment = dataOut + } + + if err := iscsi.WritePDU(conn, cmd); err != nil { + t.Fatal(err) + } + + resp, err := iscsi.ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + return resp +} + +func testWriteReadVerify(t *testing.T) { + conn, _ 
:= setupIntegrationTarget(t) + + // Write pattern to LBA 0 + writeData := make([]byte, 4096) + for i := range writeData { + writeData[i] = byte(i % 251) // prime modulus for pattern + } + + var wCDB [16]byte + wCDB[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB[2:6], 0) // LBA 0 + binary.BigEndian.PutUint16(wCDB[7:9], 1) // 1 block + + resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatalf("write failed: status %d", resp.SCSIStatus()) + } + + // Read it back + var rCDB [16]byte + rCDB[0] = iscsi.ScsiRead10 + binary.BigEndian.PutUint32(rCDB[2:6], 0) + binary.BigEndian.PutUint16(rCDB[7:9], 1) + + resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, 4096) + if resp2.Opcode() != iscsi.OpSCSIDataIn { + t.Fatalf("expected Data-In, got %s", iscsi.OpcodeName(resp2.Opcode())) + } + + if !bytes.Equal(resp2.DataSegment, writeData) { + t.Fatal("data integrity verification failed") + } +} + +func testMultiBlockWriteRead(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + // Write 4 blocks at LBA 10 + blockCount := uint16(4) + dataLen := int(blockCount) * 4096 + writeData := make([]byte, dataLen) + for i := range writeData { + writeData[i] = byte((i / 4096) + 1) // each block has different fill + } + + var wCDB [16]byte + wCDB[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB[2:6], 10) + binary.BigEndian.PutUint16(wCDB[7:9], blockCount) + + resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, uint32(dataLen)) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatalf("write failed: status %d", resp.SCSIStatus()) + } + + // Read back all 4 blocks + var rCDB [16]byte + rCDB[0] = iscsi.ScsiRead10 + binary.BigEndian.PutUint32(rCDB[2:6], 10) + binary.BigEndian.PutUint16(rCDB[7:9], blockCount) + + resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, uint32(dataLen)) + + // May be split across multiple Data-In PDUs — reassemble + var readData []byte + readData = 
append(readData, resp2.DataSegment...) + + // If first PDU doesn't have S-bit, read more + for resp2.OpSpecific1()&iscsi.FlagS == 0 { + var err error + resp2, err = iscsi.ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + readData = append(readData, resp2.DataSegment...) + } + + if !bytes.Equal(readData, writeData) { + t.Fatal("multi-block data integrity failed") + } +} + +func testInquiryCapacity(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + // TEST UNIT READY + var tuCDB [16]byte + tuCDB[0] = iscsi.ScsiTestUnitReady + resp := sendSCSICmd(t, conn, tuCDB, 2, false, false, nil, 0) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatal("test unit ready failed") + } + + // INQUIRY + var iqCDB [16]byte + iqCDB[0] = iscsi.ScsiInquiry + binary.BigEndian.PutUint16(iqCDB[3:5], 96) + resp = sendSCSICmd(t, conn, iqCDB, 3, true, false, nil, 96) + if resp.Opcode() != iscsi.OpSCSIDataIn { + t.Fatal("expected Data-In for inquiry") + } + if resp.DataSegment[0] != 0x00 { // SBC device + t.Fatal("not SBC device type") + } + + // READ CAPACITY 10 + var rcCDB [16]byte + rcCDB[0] = iscsi.ScsiReadCapacity10 + resp = sendSCSICmd(t, conn, rcCDB, 4, true, false, nil, 8) + if resp.Opcode() != iscsi.OpSCSIDataIn { + t.Fatal("expected Data-In for read capacity") + } + lastLBA := binary.BigEndian.Uint32(resp.DataSegment[0:4]) + blockSize := binary.BigEndian.Uint32(resp.DataSegment[4:8]) + if lastLBA != 1023 { // 1024 blocks, last = 1023 + t.Fatalf("last LBA: %d, expected 1023", lastLBA) + } + if blockSize != 4096 { + t.Fatalf("block size: %d", blockSize) + } +} + +func testIntSyncCache(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + var cdb [16]byte + cdb[0] = iscsi.ScsiSyncCache10 + resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatalf("sync cache failed: %d", resp.SCSIStatus()) + } +} + +func testIntTestUnitReady(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + var cdb [16]byte + cdb[0] 
= iscsi.ScsiTestUnitReady + resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatal("TUR failed") + } +} + +func testConcurrentReadersWriters(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + // Sequential writes to different LBAs, then sequential reads to verify + cmdSN := uint32(2) + for lba := uint32(0); lba < 10; lba++ { + data := bytes.Repeat([]byte{byte(lba)}, 4096) + var wCDB [16]byte + wCDB[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB[2:6], lba) + binary.BigEndian.PutUint16(wCDB[7:9], 1) + resp := sendSCSICmd(t, conn, wCDB, cmdSN, false, true, data, 4096) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatalf("write LBA %d failed", lba) + } + cmdSN++ + } + + // Read back and verify + for lba := uint32(0); lba < 10; lba++ { + var rCDB [16]byte + rCDB[0] = iscsi.ScsiRead10 + binary.BigEndian.PutUint32(rCDB[2:6], lba) + binary.BigEndian.PutUint16(rCDB[7:9], 1) + resp := sendSCSICmd(t, conn, rCDB, cmdSN, true, false, nil, 4096) + if resp.Opcode() != iscsi.OpSCSIDataIn { + t.Fatalf("LBA %d: expected Data-In", lba) + } + expected := bytes.Repeat([]byte{byte(lba)}, 4096) + if !bytes.Equal(resp.DataSegment, expected) { + t.Fatalf("LBA %d: data mismatch", lba) + } + cmdSN++ + } +} + +func testWriteAtBoundary(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + // Write at last valid LBA (1023) + data := bytes.Repeat([]byte{0xFF}, 4096) + var wCDB [16]byte + wCDB[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB[2:6], 1023) + binary.BigEndian.PutUint16(wCDB[7:9], 1) + resp := sendSCSICmd(t, conn, wCDB, 2, false, true, data, 4096) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatal("boundary write failed") + } + + // Write past end — should fail + var wCDB2 [16]byte + wCDB2[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB2[2:6], 1024) // out of bounds + binary.BigEndian.PutUint16(wCDB2[7:9], 1) + resp2 := sendSCSICmd(t, conn, wCDB2, 3, false, true, data, 
4096) + if resp2.SCSIStatus() == iscsi.SCSIStatusGood { + t.Fatal("OOB write should fail") + } +} + +func testUnmapIntegration(t *testing.T) { + conn, _ := setupIntegrationTarget(t) + + // Write data at LBA 5 + writeData := bytes.Repeat([]byte{0xCC}, 4096) + var wCDB [16]byte + wCDB[0] = iscsi.ScsiWrite10 + binary.BigEndian.PutUint32(wCDB[2:6], 5) + binary.BigEndian.PutUint16(wCDB[7:9], 1) + sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096) + + // UNMAP LBA 5 + unmapData := make([]byte, 24) + binary.BigEndian.PutUint16(unmapData[0:2], 22) + binary.BigEndian.PutUint16(unmapData[2:4], 16) + binary.BigEndian.PutUint64(unmapData[8:16], 5) + binary.BigEndian.PutUint32(unmapData[16:20], 1) + + var uCDB [16]byte + uCDB[0] = iscsi.ScsiUnmap + resp := sendSCSICmd(t, conn, uCDB, 3, false, true, unmapData, uint32(len(unmapData))) + if resp.SCSIStatus() != iscsi.SCSIStatusGood { + t.Fatalf("unmap failed: %d", resp.SCSIStatus()) + } + + // Read back — should be zeros + var rCDB [16]byte + rCDB[0] = iscsi.ScsiRead10 + binary.BigEndian.PutUint32(rCDB[2:6], 5) + binary.BigEndian.PutUint16(rCDB[7:9], 1) + resp2 := sendSCSICmd(t, conn, rCDB, 4, true, false, nil, 4096) + if resp2.Opcode() != iscsi.OpSCSIDataIn { + t.Fatal("expected Data-In") + } + zeros := make([]byte, 4096) + if !bytes.Equal(resp2.DataSegment, zeros) { + t.Fatal("unmapped block should return zeros") + } +} diff --git a/weed/storage/blockvol/iscsi/login.go b/weed/storage/blockvol/iscsi/login.go new file mode 100644 index 000000000..4d42d20ab --- /dev/null +++ b/weed/storage/blockvol/iscsi/login.go @@ -0,0 +1,396 @@ +package iscsi + +import ( + "errors" + "strconv" +) + +// Login status classes (RFC 7143, Section 11.13.5) +const ( + LoginStatusSuccess uint8 = 0x00 + LoginStatusRedirect uint8 = 0x01 + LoginStatusInitiatorErr uint8 = 0x02 + LoginStatusTargetErr uint8 = 0x03 +) + +// Login status details +const ( + LoginDetailSuccess uint8 = 0x00 + LoginDetailTargetMoved uint8 = 0x01 // redirect: permanently 
moved + LoginDetailTargetMovedTemp uint8 = 0x02 // redirect: temporarily moved + LoginDetailInitiatorError uint8 = 0x00 // initiator error (miscellaneous) + LoginDetailAuthFailure uint8 = 0x01 // authentication failure + LoginDetailAuthorizationFail uint8 = 0x02 // authorization failure + LoginDetailNotFound uint8 = 0x03 // target not found + LoginDetailTargetRemoved uint8 = 0x04 // target removed + LoginDetailUnsupported uint8 = 0x05 // unsupported version + LoginDetailTooManyConns uint8 = 0x06 // too many connections + LoginDetailMissingParam uint8 = 0x07 // missing parameter + LoginDetailNoSessionSlot uint8 = 0x08 // no session slot + LoginDetailNoTCPConn uint8 = 0x09 // no TCP connection available + LoginDetailNoSession uint8 = 0x0a // no existing session + LoginDetailTargetError uint8 = 0x00 // target error (miscellaneous) + LoginDetailServiceUnavail uint8 = 0x01 // service unavailable + LoginDetailOutOfResources uint8 = 0x02 // out of resources +) + +var ( + ErrLoginInvalidStage = errors.New("iscsi: invalid login stage transition") + ErrLoginMissingParam = errors.New("iscsi: missing required login parameter") + ErrLoginInvalidISID = errors.New("iscsi: invalid ISID") + ErrLoginTargetNotFound = errors.New("iscsi: target not found") + ErrLoginSessionExists = errors.New("iscsi: session already exists for this ISID") + ErrLoginInvalidRequest = errors.New("iscsi: invalid login request") +) + +// LoginPhase tracks the current phase of login negotiation. +type LoginPhase int + +const ( + LoginPhaseStart LoginPhase = iota // before first PDU + LoginPhaseSecurity // CSG=SecurityNeg + LoginPhaseOperational // CSG=LoginOp + LoginPhaseDone // transition to FFP complete +) + +// TargetConfig holds the target-side negotiation defaults. 
+type TargetConfig struct { + TargetName string + TargetAlias string + MaxRecvDataSegmentLength int + MaxBurstLength int + FirstBurstLength int + MaxConnections int + MaxOutstandingR2T int + DefaultTime2Wait int + DefaultTime2Retain int + DataPDUInOrder bool + DataSequenceInOrder bool + InitialR2T bool + ImmediateData bool + ErrorRecoveryLevel int +} + +// DefaultTargetConfig returns sensible defaults for a target. +func DefaultTargetConfig() TargetConfig { + return TargetConfig{ + MaxRecvDataSegmentLength: 262144, // 256KB + MaxBurstLength: 262144, + FirstBurstLength: 65536, + MaxConnections: 1, + MaxOutstandingR2T: 1, + DefaultTime2Wait: 2, + DefaultTime2Retain: 0, + DataPDUInOrder: true, + DataSequenceInOrder: true, + InitialR2T: true, + ImmediateData: true, + ErrorRecoveryLevel: 0, + } +} + +// LoginNegotiator handles the target side of login negotiation. +type LoginNegotiator struct { + config TargetConfig + phase LoginPhase + isid [6]byte + tsih uint16 + targetOK bool // target name validated + + // Negotiated values (updated during negotiation) + NegMaxRecvDataSegLen int + NegMaxBurstLength int + NegFirstBurstLength int + NegInitialR2T bool + NegImmediateData bool + + // Initiator/target info captured during login + InitiatorName string + TargetName string + SessionType string // "Normal" or "Discovery" +} + +// NewLoginNegotiator creates a negotiator for a new login sequence. +func NewLoginNegotiator(config TargetConfig) *LoginNegotiator { + return &LoginNegotiator{ + config: config, + phase: LoginPhaseStart, + NegMaxRecvDataSegLen: config.MaxRecvDataSegmentLength, + NegMaxBurstLength: config.MaxBurstLength, + NegFirstBurstLength: config.FirstBurstLength, + NegInitialR2T: config.InitialR2T, + NegImmediateData: config.ImmediateData, + } +} + +// HandleLoginPDU processes one login request PDU and returns the response PDU. +// It manages stage transitions and parameter negotiation. 
func (ln *LoginNegotiator) HandleLoginPDU(req *PDU, resolver TargetResolver) *PDU {
	// Pre-build the response shell; every exit path below returns resp,
	// either marked as a reject (via setLoginReject) or as a success.
	resp := &PDU{}
	resp.SetOpcode(OpLoginResp)
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetISID(req.ISID())
	resp.SetTSIH(req.TSIH())

	// Validate opcode
	if req.Opcode() != OpLoginReq {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	csg := req.LoginCSG()         // stage the initiator is currently in
	nsg := req.LoginNSG()         // stage the initiator wants to move to
	transit := req.LoginTransit() // T-bit: initiator requests the stage change

	// Parse text parameters from data segment
	params, err := ParseParams(req.DataSegment)
	if err != nil {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Process based on current stage
	respParams := NewParams()

	switch csg {
	case StageSecurityNeg:
		// Security stage is only legal from the very start or while
		// already in security negotiation (re-entry with more keys).
		if ln.phase != LoginPhaseStart && ln.phase != LoginPhaseSecurity {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
			return resp
		}
		ln.phase = LoginPhaseSecurity

		// Capture initiator name; required unless already seen on an
		// earlier PDU of this login sequence.
		if name, ok := params.Get("InitiatorName"); ok {
			ln.InitiatorName = name
		} else if ln.InitiatorName == "" {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// Session type; defaults to "Normal" when never supplied.
		if st, ok := params.Get("SessionType"); ok {
			ln.SessionType = st
		}
		if ln.SessionType == "" {
			ln.SessionType = "Normal"
		}

		// Target name (required for Normal sessions). Discovery sessions
		// may omit it; for Normal sessions the resolver decides whether
		// the named target exists.
		if tn, ok := params.Get("TargetName"); ok {
			if ln.SessionType == "Normal" {
				if resolver == nil || !resolver.HasTarget(tn) {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailNotFound)
					return resp
				}
			}
			ln.TargetName = tn
			ln.targetOK = true
		} else if ln.SessionType == "Normal" && !ln.targetOK {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// ISID identifies the initiator side of the session.
		ln.isid = req.ISID()

		// We don't implement CHAP — declare AuthMethod=None
		respParams.Set("AuthMethod", "None")

		// Honor the requested transition: security may move either into
		// operational negotiation or straight to full-feature phase.
		if transit {
			if nsg == StageLoginOp {
				ln.phase = LoginPhaseOperational
			} else if nsg == StageFullFeature {
				ln.phase = LoginPhaseDone
			}
		}

	case StageLoginOp:
		if ln.phase != LoginPhaseOperational && ln.phase != LoginPhaseSecurity {
			// Allow direct jump to LoginOp if security was skipped
			if ln.phase == LoginPhaseStart {
				// Need InitiatorName at minimum
				if name, ok := params.Get("InitiatorName"); ok {
					ln.InitiatorName = name
				} else if ln.InitiatorName == "" {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
					return resp
				}
			} else {
				// Any other phase (e.g. already Done) is an illegal regression.
				setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
				return resp
			}
		}
		ln.phase = LoginPhaseOperational

		// Negotiate operational parameters
		ln.negotiateParams(params, respParams)

		if transit && nsg == StageFullFeature {
			ln.phase = LoginPhaseDone
		}

	default:
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Build response: echo the stages and T-bit, and report success.
	resp.SetLoginStages(csg, nsg)
	if transit {
		resp.SetLoginTransit(true)
	}
	resp.SetLoginStatus(LoginStatusSuccess, LoginDetailSuccess)

	// Assign TSIH on first successful login.
	// NOTE(review): this returns a non-zero TSIH on every success response,
	// including intermediate (non-final) ones; RFC 7143 reserves the TSIH
	// assignment for the final login response — confirm initiators tolerate
	// the early value.
	if ln.tsih == 0 {
		ln.tsih = 1 // simplified: single session
	}
	resp.SetTSIH(ln.tsih)

	// Encode response params
	if respParams.Len() > 0 {
		resp.DataSegment = respParams.Encode()
	}

	return resp
}

// Done reports whether login negotiation has completed (reached the
// full-feature-phase transition).
func (ln *LoginNegotiator) Done() bool {
	return ln.phase == LoginPhaseDone
}

// Phase returns the current login phase.
func (ln *LoginNegotiator) Phase() LoginPhase {
	return ln.phase
}

// negotiateParams processes operational parameter negotiation.
func (ln *LoginNegotiator) negotiateParams(req *Params, resp *Params) {
	// Answer every key the initiator offered; unrecognized keys get the
	// standard "NotUnderstood" reply.
	req.Each(func(key, value string) {
		switch key {
		case "MaxRecvDataSegmentLength":
			// NOTE(review): the negotiated value v is stored but the reply
			// declares the target's own configured limit — this key is
			// declarative per direction in RFC 7143, so the asymmetry may be
			// intentional; confirm NegotiateNumber's semantics match.
			if v, err := NegotiateNumber(value, ln.config.MaxRecvDataSegmentLength, 512, 16777215); err == nil {
				ln.NegMaxRecvDataSegLen = v
				resp.Set(key, strconv.Itoa(ln.config.MaxRecvDataSegmentLength))
			}
		case "MaxBurstLength":
			if v, err := NegotiateNumber(value, ln.config.MaxBurstLength, 512, 16777215); err == nil {
				ln.NegMaxBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "FirstBurstLength":
			if v, err := NegotiateNumber(value, ln.config.FirstBurstLength, 512, 16777215); err == nil {
				ln.NegFirstBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "InitialR2T":
			if v, err := NegotiateBool(value, ln.config.InitialR2T); err == nil {
				// InitialR2T uses OR semantics: result is Yes if either side says Yes.
				// The else branch only runs when both sides are No, falling back
				// to NegotiateBool's result.
				if value == "Yes" || ln.config.InitialR2T {
					ln.NegInitialR2T = true
				} else {
					ln.NegInitialR2T = v
				}
				resp.Set(key, BoolStr(ln.NegInitialR2T))
			}
		case "ImmediateData":
			if v, err := NegotiateBool(value, ln.config.ImmediateData); err == nil {
				ln.NegImmediateData = v
				resp.Set(key, BoolStr(v))
			}
		// The following keys are answered directly from the target's
		// configuration (no per-session negotiation state kept).
		case "MaxConnections":
			resp.Set(key, strconv.Itoa(ln.config.MaxConnections))
		case "DataPDUInOrder":
			resp.Set(key, BoolStr(ln.config.DataPDUInOrder))
		case "DataSequenceInOrder":
			resp.Set(key, BoolStr(ln.config.DataSequenceInOrder))
		case "DefaultTime2Wait":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Wait))
		case "DefaultTime2Retain":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Retain))
		case "MaxOutstandingR2T":
			resp.Set(key, strconv.Itoa(ln.config.MaxOutstandingR2T))
		case "ErrorRecoveryLevel":
			resp.Set(key, strconv.Itoa(ln.config.ErrorRecoveryLevel))
		case "HeaderDigest":
			// Digests are not implemented; always answer None.
			resp.Set(key, "None")
		case "DataDigest":
			resp.Set(key, "None")
		case "TargetName", "InitiatorName", "SessionType", "AuthMethod":
			// Already handled or declarative — skip
		case "TargetAlias":
			// Informational from initiator — skip
		default:
			// Unknown keys: respond with NotUnderstood
			resp.Set(key, "NotUnderstood")
		}
	})

	// Always declare our TargetAlias if configured
	if ln.config.TargetAlias != "" {
		resp.Set("TargetAlias", ln.config.TargetAlias)
	}
}

// TargetResolver allows the login state machine to check if a target exists.
type TargetResolver interface {
	HasTarget(name string) bool
}

// setLoginReject marks resp as a login failure with the given status
// class/detail and clears the transit bit so no stage change occurs.
func setLoginReject(resp *PDU, class, detail uint8) {
	resp.SetLoginStatus(class, detail)
	resp.SetLoginTransit(false)
}

// LoginResult contains the outcome of a completed login negotiation.
type LoginResult struct {
	InitiatorName     string
	TargetName        string
	SessionType       string
	ISID              [6]byte
	TSIH              uint16
	MaxRecvDataSegLen int
	MaxBurstLength    int
	FirstBurstLength  int
	InitialR2T        bool
	ImmediateData     bool
}

// Result returns the negotiation outcome as a flat snapshot of the
// negotiator's state. Only valid after Done() returns true.
func (ln *LoginNegotiator) Result() LoginResult {
	return LoginResult{
		InitiatorName:     ln.InitiatorName,
		TargetName:        ln.TargetName,
		SessionType:       ln.SessionType,
		ISID:              ln.isid,
		TSIH:              ln.tsih,
		MaxRecvDataSegLen: ln.NegMaxRecvDataSegLen,
		MaxBurstLength:    ln.NegMaxBurstLength,
		FirstBurstLength:  ln.NegFirstBurstLength,
		InitialR2T:        ln.NegInitialR2T,
		ImmediateData:     ln.NegImmediateData,
	}
}

// BuildRedirectResponse creates a login response that redirects the initiator
// to a different target address.
+func BuildRedirectResponse(req *PDU, addr string, permanent bool) *PDU { + resp := &PDU{} + resp.SetOpcode(OpLoginResp) + resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) + resp.SetISID(req.ISID()) + + detail := LoginDetailTargetMovedTemp + if permanent { + detail = LoginDetailTargetMoved + } + resp.SetLoginStatus(LoginStatusRedirect, detail) + + params := NewParams() + params.Set("TargetAddress", addr) + resp.DataSegment = params.Encode() + + return resp +} diff --git a/weed/storage/blockvol/iscsi/login_test.go b/weed/storage/blockvol/iscsi/login_test.go new file mode 100644 index 000000000..de7464a14 --- /dev/null +++ b/weed/storage/blockvol/iscsi/login_test.go @@ -0,0 +1,444 @@ +package iscsi + +import ( + "testing" +) + +// mockResolver implements TargetResolver for tests. +type mockResolver struct { + targets map[string]bool +} + +func (m *mockResolver) HasTarget(name string) bool { + return m.targets[name] +} + +func newResolver(names ...string) *mockResolver { + r := &mockResolver{targets: make(map[string]bool)} + for _, n := range names { + r.targets[n] = true + } + return r +} + +const testTargetName = "iqn.2024.com.seaweedfs:vol1" +const testInitiatorName = "iqn.2024.com.test:client1" + +func TestLogin(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"single_pdu_security_to_ffp", testSinglePDUSecurityToFFP}, + {"two_phase_login", testTwoPhaseLogin}, + {"security_to_loginop_to_ffp", testSecurityToLoginOpToFFP}, + {"missing_initiator_name", testMissingInitiatorName}, + {"target_not_found", testTargetNotFound}, + {"discovery_session_no_target", testDiscoverySessionNoTarget}, + {"wrong_opcode", testWrongOpcode}, + {"operational_negotiation", testOperationalNegotiation}, + {"redirect_permanent", testRedirectPermanent}, + {"redirect_temporary", testRedirectTemporary}, + {"login_result", testLoginResult}, + {"unknown_key_not_understood", testUnknownKeyNotUnderstood}, + {"header_data_digest_none", testHeaderDataDigestNone}, + 
{"duplicate_login_pdu", testDuplicateLoginPDU}, + {"invalid_csg", testInvalidCSG}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func makeLoginReq(csg, nsg uint8, transit bool, params *Params) *PDU { + p := &PDU{} + p.SetOpcode(OpLoginReq) + p.SetLoginStages(csg, nsg) + if transit { + p.SetLoginTransit(true) + } + isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} + p.SetISID(isid) + if params != nil { + p.DataSegment = params.Encode() + } + return p +} + +func testSinglePDUSecurityToFFP(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + params.Set("SessionType", "Normal") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + resp := ln.HandleLoginPDU(req, resolver) + + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status class: %d", resp.LoginStatusClass()) + } + if !resp.LoginTransit() { + t.Fatal("transit should be set") + } + if !ln.Done() { + t.Fatal("should be done") + } + if resp.TSIH() == 0 { + t.Fatal("TSIH should be assigned") + } + + // Verify AuthMethod=None in response + rp, err := ParseParams(resp.DataSegment) + if err != nil { + t.Fatal(err) + } + if v, ok := rp.Get("AuthMethod"); !ok || v != "None" { + t.Fatalf("AuthMethod: %q, %v", v, ok) + } +} + +func testTwoPhaseLogin(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + // Phase 1: Security -> LoginOp + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + params.Set("SessionType", "Normal") + + req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params) + resp := ln.HandleLoginPDU(req, resolver) + + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("phase 1 status: %d", resp.LoginStatusClass()) 
+ } + if ln.Done() { + t.Fatal("should not be done yet") + } + + // Phase 2: LoginOp -> FFP + params2 := NewParams() + params2.Set("MaxRecvDataSegmentLength", "65536") + + req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2) + resp2 := ln.HandleLoginPDU(req2, resolver) + + if resp2.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("phase 2 status: %d", resp2.LoginStatusClass()) + } + if !ln.Done() { + t.Fatal("should be done") + } +} + +func testSecurityToLoginOpToFFP(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + // Security (no transit) + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + + req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params) + resp := ln.HandleLoginPDU(req, resolver) + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status: %d", resp.LoginStatusClass()) + } + + // Security -> LoginOp (transit) + req2 := makeLoginReq(StageSecurityNeg, StageLoginOp, true, NewParams()) + resp2 := ln.HandleLoginPDU(req2, resolver) + if resp2.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status: %d", resp2.LoginStatusClass()) + } + + // LoginOp -> FFP (transit) + req3 := makeLoginReq(StageLoginOp, StageFullFeature, true, NewParams()) + resp3 := ln.HandleLoginPDU(req3, resolver) + if resp3.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status: %d", resp3.LoginStatusClass()) + } + if !ln.Done() { + t.Fatal("should be done") + } +} + +func testMissingInitiatorName(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + // No InitiatorName + params := NewParams() + params.Set("TargetName", testTargetName) + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + resp := ln.HandleLoginPDU(req, resolver) + + if resp.LoginStatusClass() != LoginStatusInitiatorErr { + t.Fatalf("expected initiator error, got 
%d", resp.LoginStatusClass()) + } + if resp.LoginStatusDetail() != LoginDetailMissingParam { + t.Fatalf("expected missing param, got %d", resp.LoginStatusDetail()) + } +} + +func testTargetNotFound(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver() // empty + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", "iqn.2024.com.seaweedfs:nonexistent") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + resp := ln.HandleLoginPDU(req, resolver) + + if resp.LoginStatusClass() != LoginStatusInitiatorErr { + t.Fatalf("expected initiator error, got %d", resp.LoginStatusClass()) + } + if resp.LoginStatusDetail() != LoginDetailNotFound { + t.Fatalf("expected not found, got %d", resp.LoginStatusDetail()) + } +} + +func testDiscoverySessionNoTarget(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("SessionType", "Discovery") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + resp := ln.HandleLoginPDU(req, nil) // no resolver needed + + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) + } + if !ln.Done() { + t.Fatal("should be done") + } + if ln.SessionType != "Discovery" { + t.Fatalf("session type: %q", ln.SessionType) + } +} + +func testWrongOpcode(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + + req := &PDU{} + req.SetOpcode(OpSCSICmd) // wrong + resp := ln.HandleLoginPDU(req, nil) + + if resp.LoginStatusClass() != LoginStatusInitiatorErr { + t.Fatalf("expected error, got %d", resp.LoginStatusClass()) + } +} + +func testOperationalNegotiation(t *testing.T) { + config := DefaultTargetConfig() + config.TargetAlias = "test-target" + ln := NewLoginNegotiator(config) + resolver := newResolver(testTargetName) + + // First: security phase + 
params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params) + ln.HandleLoginPDU(req, resolver) + + // Second: operational negotiation + params2 := NewParams() + params2.Set("MaxRecvDataSegmentLength", "65536") + params2.Set("MaxBurstLength", "131072") + params2.Set("FirstBurstLength", "32768") + params2.Set("InitialR2T", "Yes") + params2.Set("ImmediateData", "No") + params2.Set("HeaderDigest", "CRC32C,None") + params2.Set("DataDigest", "CRC32C,None") + + req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2) + resp := ln.HandleLoginPDU(req2, resolver) + + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("status: %d", resp.LoginStatusClass()) + } + + rp, err := ParseParams(resp.DataSegment) + if err != nil { + t.Fatal(err) + } + + // Verify negotiated values + if v, ok := rp.Get("HeaderDigest"); !ok || v != "None" { + t.Fatalf("HeaderDigest: %q", v) + } + if v, ok := rp.Get("DataDigest"); !ok || v != "None" { + t.Fatalf("DataDigest: %q", v) + } + if v, ok := rp.Get("InitialR2T"); !ok || v != "Yes" { + t.Fatalf("InitialR2T: %q", v) + } + if v, ok := rp.Get("ImmediateData"); !ok || v != "No" { + t.Fatalf("ImmediateData: %q", v) + } + if v, ok := rp.Get("TargetAlias"); !ok || v != "test-target" { + t.Fatalf("TargetAlias: %q", v) + } + + // Check negotiated state + if ln.NegInitialR2T != true { + t.Fatal("NegInitialR2T should be true") + } + if ln.NegImmediateData != false { + t.Fatal("NegImmediateData should be false") + } +} + +func testRedirectPermanent(t *testing.T) { + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams()) + resp := BuildRedirectResponse(req, "10.0.0.1:3260,1", true) + + if resp.LoginStatusClass() != LoginStatusRedirect { + t.Fatalf("class: %d", resp.LoginStatusClass()) + } + if resp.LoginStatusDetail() != LoginDetailTargetMoved { + t.Fatalf("detail: %d", resp.LoginStatusDetail()) + 
} + + rp, err := ParseParams(resp.DataSegment) + if err != nil { + t.Fatal(err) + } + if v, _ := rp.Get("TargetAddress"); v != "10.0.0.1:3260,1" { + t.Fatalf("TargetAddress: %q", v) + } +} + +func testRedirectTemporary(t *testing.T) { + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams()) + resp := BuildRedirectResponse(req, "192.168.1.100:3260,2", false) + + if resp.LoginStatusDetail() != LoginDetailTargetMovedTemp { + t.Fatalf("detail: %d", resp.LoginStatusDetail()) + } +} + +func testLoginResult(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + ln.HandleLoginPDU(req, resolver) + + result := ln.Result() + if result.InitiatorName != testInitiatorName { + t.Fatalf("initiator: %q", result.InitiatorName) + } + if result.SessionType != "Normal" { + t.Fatalf("session type: %q", result.SessionType) + } + if result.TSIH == 0 { + t.Fatal("TSIH should be assigned") + } +} + +func testUnknownKeyNotUnderstood(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + // Security phase first + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params) + ln.HandleLoginPDU(req, resolver) + + // Op phase with unknown key + params2 := NewParams() + params2.Set("X-CustomKey", "whatever") + req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2) + resp := ln.HandleLoginPDU(req2, resolver) + + rp, err := ParseParams(resp.DataSegment) + if err != nil { + t.Fatal(err) + } + if v, ok := rp.Get("X-CustomKey"); !ok || v != "NotUnderstood" { + t.Fatalf("expected NotUnderstood, got %q, %v", v, ok) + } +} + +func testHeaderDataDigestNone(t 
*testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params) + ln.HandleLoginPDU(req, resolver) + + params2 := NewParams() + params2.Set("HeaderDigest", "CRC32C") + params2.Set("DataDigest", "CRC32C") + req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2) + resp := ln.HandleLoginPDU(req2, resolver) + + rp, _ := ParseParams(resp.DataSegment) + if v, _ := rp.Get("HeaderDigest"); v != "None" { + t.Fatalf("HeaderDigest: %q (we always negotiate None)", v) + } + if v, _ := rp.Get("DataDigest"); v != "None" { + t.Fatalf("DataDigest: %q", v) + } +} + +func testDuplicateLoginPDU(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + resolver := newResolver(testTargetName) + + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + + // Send same security PDU twice (no transit) + req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params) + resp1 := ln.HandleLoginPDU(req, resolver) + if resp1.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("first: %d", resp1.LoginStatusClass()) + } + + // Same stage again should still work + resp2 := ln.HandleLoginPDU(req, resolver) + if resp2.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("second: %d", resp2.LoginStatusClass()) + } +} + +func testInvalidCSG(t *testing.T) { + ln := NewLoginNegotiator(DefaultTargetConfig()) + + // CSG=FullFeature (3) is invalid as a current stage + req := &PDU{} + req.SetOpcode(OpLoginReq) + req.SetLoginStages(StageFullFeature, StageFullFeature) + req.SetLoginTransit(true) + isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} + req.SetISID(isid) + + resp := ln.HandleLoginPDU(req, nil) + if resp.LoginStatusClass() != LoginStatusInitiatorErr { + t.Fatalf("expected error for 
invalid CSG, got %d", resp.LoginStatusClass()) + } +} diff --git a/weed/storage/blockvol/iscsi/params.go b/weed/storage/blockvol/iscsi/params.go new file mode 100644 index 000000000..603afacc4 --- /dev/null +++ b/weed/storage/blockvol/iscsi/params.go @@ -0,0 +1,183 @@ +package iscsi + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +// iSCSI key-value text parameters (RFC 7143, Section 6). +// Parameters are encoded as "Key=Value\0" in the data segment. + +var ( + ErrMalformedParam = errors.New("iscsi: malformed parameter (missing '=')") + ErrEmptyKey = errors.New("iscsi: empty parameter key") + ErrDuplicateKey = errors.New("iscsi: duplicate parameter key") +) + +// Params is an ordered list of iSCSI key-value parameters. +// Order matters for negotiation, so we use a slice rather than a map. +type Params struct { + items []paramItem +} + +type paramItem struct { + key string + value string +} + +// ParseParams decodes "Key=Value\0Key=Value\0..." from raw bytes. +// Empty trailing segments (from a trailing \0) are ignored. +// Returns ErrDuplicateKey if the same key appears more than once. +func ParseParams(data []byte) (*Params, error) { + p := &Params{} + if len(data) == 0 { + return p, nil + } + + seen := make(map[string]bool) + s := string(data) + + // Split by null separator + parts := strings.Split(s, "\x00") + for _, part := range parts { + if part == "" { + continue // trailing null or empty + } + idx := strings.IndexByte(part, '=') + if idx < 0 { + return nil, fmt.Errorf("%w: %q", ErrMalformedParam, part) + } + key := part[:idx] + if key == "" { + return nil, ErrEmptyKey + } + value := part[idx+1:] + + if seen[key] { + return nil, fmt.Errorf("%w: %q", ErrDuplicateKey, key) + } + seen[key] = true + p.items = append(p.items, paramItem{key: key, value: value}) + } + + return p, nil +} + +// Encode serializes parameters to "Key=Value\0" format. 
+func (p *Params) Encode() []byte { + if len(p.items) == 0 { + return nil + } + var b strings.Builder + for _, item := range p.items { + b.WriteString(item.key) + b.WriteByte('=') + b.WriteString(item.value) + b.WriteByte(0) + } + return []byte(b.String()) +} + +// Get returns the value for a key, or ("", false) if not present. +func (p *Params) Get(key string) (string, bool) { + for _, item := range p.items { + if item.key == key { + return item.value, true + } + } + return "", false +} + +// Set adds or replaces a parameter. If the key already exists, its value +// is updated in place. Otherwise, the key is appended. +func (p *Params) Set(key, value string) { + for i, item := range p.items { + if item.key == key { + p.items[i].value = value + return + } + } + p.items = append(p.items, paramItem{key: key, value: value}) +} + +// Del removes a key if present. +func (p *Params) Del(key string) { + for i, item := range p.items { + if item.key == key { + p.items = append(p.items[:i], p.items[i+1:]...) + return + } + } +} + +// Keys returns all keys in order. +func (p *Params) Keys() []string { + keys := make([]string, len(p.items)) + for i, item := range p.items { + keys[i] = item.key + } + return keys +} + +// Len returns the number of parameters. +func (p *Params) Len() int { return len(p.items) } + +// Each iterates over all key-value pairs in order. +func (p *Params) Each(fn func(key, value string)) { + for _, item := range p.items { + fn(item.key, item.value) + } +} + +// --- Negotiation helpers --- + +// NegotiateNumber applies the iSCSI numeric negotiation rule: +// for "min" semantics (e.g., MaxRecvDataSegmentLength), return min(offer, ours). +// For "max" semantics (e.g., MaxBurstLength), return min(offer, ours). +// Both directions clamp to the smaller value, so min() is the general rule. 
+func NegotiateNumber(offered string, ours int, min, max int) (int, error) { + v, err := strconv.Atoi(offered) + if err != nil { + return 0, fmt.Errorf("iscsi: invalid numeric value %q: %w", offered, err) + } + // Clamp to valid range + if v < min { + v = min + } + if v > max { + v = max + } + // Standard negotiation: result is min(offer, ours) + if ours < v { + return ours, nil + } + return v, nil +} + +// NegotiateBool applies boolean negotiation (AND semantics per RFC 7143). +// Result is true only if both sides agree to true. +func NegotiateBool(offered string, ours bool) (bool, error) { + switch offered { + case "Yes": + return ours, nil + case "No": + return false, nil + default: + return false, fmt.Errorf("iscsi: invalid boolean value %q", offered) + } +} + +// BoolStr returns "Yes" or "No" for a boolean value. +func BoolStr(v bool) string { + if v { + return "Yes" + } + return "No" +} + +// NewParams creates a new empty Params. +func NewParams() *Params { + return &Params{} +} diff --git a/weed/storage/blockvol/iscsi/params_test.go b/weed/storage/blockvol/iscsi/params_test.go new file mode 100644 index 000000000..69a6081e2 --- /dev/null +++ b/weed/storage/blockvol/iscsi/params_test.go @@ -0,0 +1,357 @@ +package iscsi + +import ( + "bytes" + "strings" + "testing" +) + +func TestParams(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"parse_single", testParseSingle}, + {"parse_multiple", testParseMultiple}, + {"parse_empty", testParseEmpty}, + {"parse_trailing_null", testParseTrailingNull}, + {"parse_malformed_no_equals", testParseMalformedNoEquals}, + {"parse_empty_key", testParseEmptyKey}, + {"parse_empty_value", testParseEmptyValue}, + {"parse_duplicate_key", testParseDuplicateKey}, + {"roundtrip", testParamsRoundtrip}, + {"encode_empty", testEncodeEmpty}, + {"set_new_key", testSetNewKey}, + {"set_existing_key", testSetExistingKey}, + {"del_key", testDelKey}, + {"del_nonexistent", testDelNonexistent}, + {"keys_order", 
testKeysOrder}, + {"each", testEach}, + {"negotiate_number_min", testNegotiateNumberMin}, + {"negotiate_number_clamp", testNegotiateNumberClamp}, + {"negotiate_number_invalid", testNegotiateNumberInvalid}, + {"negotiate_bool_and", testNegotiateBoolAnd}, + {"negotiate_bool_invalid", testNegotiateBoolInvalid}, + {"bool_str", testBoolStr}, + {"value_with_equals", testValueWithEquals}, + {"parse_binary_data_value", testParseBinaryDataValue}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testParseSingle(t *testing.T) { + data := []byte("TargetName=iqn.2024.com.seaweedfs:vol1\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + if p.Len() != 1 { + t.Fatalf("expected 1 param, got %d", p.Len()) + } + v, ok := p.Get("TargetName") + if !ok || v != "iqn.2024.com.seaweedfs:vol1" { + t.Fatalf("got %q, %v", v, ok) + } +} + +func testParseMultiple(t *testing.T) { + data := []byte("InitiatorName=iqn.2024.com.test\x00TargetName=iqn.2024.com.seaweedfs:vol1\x00SessionType=Normal\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + if p.Len() != 3 { + t.Fatalf("expected 3 params, got %d", p.Len()) + } + keys := p.Keys() + if keys[0] != "InitiatorName" || keys[1] != "TargetName" || keys[2] != "SessionType" { + t.Fatalf("wrong order: %v", keys) + } +} + +func testParseEmpty(t *testing.T) { + p, err := ParseParams(nil) + if err != nil { + t.Fatal(err) + } + if p.Len() != 0 { + t.Fatalf("expected 0, got %d", p.Len()) + } + + p2, err := ParseParams([]byte{}) + if err != nil { + t.Fatal(err) + } + if p2.Len() != 0 { + t.Fatalf("expected 0, got %d", p2.Len()) + } +} + +func testParseTrailingNull(t *testing.T) { + // Multiple trailing nulls should not cause errors + data := []byte("Key=Value\x00\x00\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + if p.Len() != 1 { + t.Fatalf("expected 1, got %d", p.Len()) + } +} + +func testParseMalformedNoEquals(t *testing.T) { + 
data := []byte("KeyWithoutValue\x00") + _, err := ParseParams(data) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "malformed") { + t.Fatalf("unexpected error: %v", err) + } +} + +func testParseEmptyKey(t *testing.T) { + data := []byte("=Value\x00") + _, err := ParseParams(data) + if err != ErrEmptyKey { + t.Fatalf("expected ErrEmptyKey, got %v", err) + } +} + +func testParseEmptyValue(t *testing.T) { + // Empty value is valid in iSCSI (e.g., reject with empty value) + data := []byte("Key=\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + v, ok := p.Get("Key") + if !ok || v != "" { + t.Fatalf("got %q, %v", v, ok) + } +} + +func testParseDuplicateKey(t *testing.T) { + data := []byte("Key=Value1\x00Key=Value2\x00") + _, err := ParseParams(data) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "duplicate") { + t.Fatalf("unexpected error: %v", err) + } +} + +func testParamsRoundtrip(t *testing.T) { + original := []byte("InitiatorName=iqn.2024.com.test\x00MaxRecvDataSegmentLength=65536\x00") + p, err := ParseParams(original) + if err != nil { + t.Fatal(err) + } + + encoded := p.Encode() + if !bytes.Equal(encoded, original) { + t.Fatalf("roundtrip mismatch:\n got: %q\n want: %q", encoded, original) + } +} + +func testEncodeEmpty(t *testing.T) { + p := NewParams() + if encoded := p.Encode(); encoded != nil { + t.Fatalf("expected nil, got %q", encoded) + } +} + +func testSetNewKey(t *testing.T) { + p := NewParams() + p.Set("Key1", "Val1") + p.Set("Key2", "Val2") + if p.Len() != 2 { + t.Fatalf("expected 2, got %d", p.Len()) + } + v, _ := p.Get("Key2") + if v != "Val2" { + t.Fatalf("got %q", v) + } +} + +func testSetExistingKey(t *testing.T) { + p := NewParams() + p.Set("Key", "Old") + p.Set("Key", "New") + if p.Len() != 1 { + t.Fatalf("expected 1, got %d", p.Len()) + } + v, _ := p.Get("Key") + if v != "New" { + t.Fatalf("got %q", v) + } +} + +func testDelKey(t *testing.T) { 
+ p := NewParams() + p.Set("A", "1") + p.Set("B", "2") + p.Set("C", "3") + p.Del("B") + if p.Len() != 2 { + t.Fatalf("expected 2, got %d", p.Len()) + } + if _, ok := p.Get("B"); ok { + t.Fatal("B should be deleted") + } + keys := p.Keys() + if keys[0] != "A" || keys[1] != "C" { + t.Fatalf("wrong keys: %v", keys) + } +} + +func testDelNonexistent(t *testing.T) { + p := NewParams() + p.Set("A", "1") + p.Del("Z") // should not panic + if p.Len() != 1 { + t.Fatalf("expected 1, got %d", p.Len()) + } +} + +func testKeysOrder(t *testing.T) { + data := []byte("Z=1\x00A=2\x00M=3\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + keys := p.Keys() + if keys[0] != "Z" || keys[1] != "A" || keys[2] != "M" { + t.Fatalf("order not preserved: %v", keys) + } +} + +func testEach(t *testing.T) { + p := NewParams() + p.Set("A", "1") + p.Set("B", "2") + var collected []string + p.Each(func(k, v string) { + collected = append(collected, k+"="+v) + }) + if len(collected) != 2 || collected[0] != "A=1" || collected[1] != "B=2" { + t.Fatalf("Each: %v", collected) + } +} + +func testNegotiateNumberMin(t *testing.T) { + // Both sides offer a value, result is min + result, err := NegotiateNumber("65536", 262144, 512, 16777215) + if err != nil { + t.Fatal(err) + } + if result != 65536 { + t.Fatalf("expected 65536, got %d", result) + } + + // Our value is smaller + result, err = NegotiateNumber("262144", 65536, 512, 16777215) + if err != nil { + t.Fatal(err) + } + if result != 65536 { + t.Fatalf("expected 65536, got %d", result) + } +} + +func testNegotiateNumberClamp(t *testing.T) { + // Offered value below minimum + result, err := NegotiateNumber("100", 65536, 512, 16777215) + if err != nil { + t.Fatal(err) + } + if result != 512 { + t.Fatalf("expected 512, got %d", result) + } + + // Offered value above maximum + result, err = NegotiateNumber("99999999", 65536, 512, 16777215) + if err != nil { + t.Fatal(err) + } + if result != 65536 { + t.Fatalf("expected 65536, got 
%d", result) + } +} + +func testNegotiateNumberInvalid(t *testing.T) { + _, err := NegotiateNumber("notanumber", 65536, 512, 16777215) + if err == nil { + t.Fatal("expected error") + } +} + +func testNegotiateBoolAnd(t *testing.T) { + // Both yes + r, err := NegotiateBool("Yes", true) + if err != nil || !r { + t.Fatalf("Yes+true = %v, %v", r, err) + } + // Initiator yes, target no + r, err = NegotiateBool("Yes", false) + if err != nil || r { + t.Fatalf("Yes+false = %v, %v", r, err) + } + // Initiator no, target yes + r, err = NegotiateBool("No", true) + if err != nil || r { + t.Fatalf("No+true = %v, %v", r, err) + } + // Both no + r, err = NegotiateBool("No", false) + if err != nil || r { + t.Fatalf("No+false = %v, %v", r, err) + } +} + +func testNegotiateBoolInvalid(t *testing.T) { + _, err := NegotiateBool("Maybe", true) + if err == nil { + t.Fatal("expected error") + } +} + +func testBoolStr(t *testing.T) { + if BoolStr(true) != "Yes" { + t.Fatal("true should be Yes") + } + if BoolStr(false) != "No" { + t.Fatal("false should be No") + } +} + +func testValueWithEquals(t *testing.T) { + // Value containing '=' is valid (only first '=' splits key from value) + data := []byte("TargetAddress=10.0.0.1:3260,1\x00Key=a=b=c\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + v, ok := p.Get("Key") + if !ok || v != "a=b=c" { + t.Fatalf("got %q, %v", v, ok) + } +} + +func testParseBinaryDataValue(t *testing.T) { + // Values can contain arbitrary bytes except null + data := []byte("Key=\x01\x02\x03\x00") + p, err := ParseParams(data) + if err != nil { + t.Fatal(err) + } + v, ok := p.Get("Key") + if !ok || v != "\x01\x02\x03" { + t.Fatalf("got %q, %v", v, ok) + } +} diff --git a/weed/storage/blockvol/iscsi/pdu.go b/weed/storage/blockvol/iscsi/pdu.go new file mode 100644 index 000000000..f6037c1cb --- /dev/null +++ b/weed/storage/blockvol/iscsi/pdu.go @@ -0,0 +1,447 @@ +package iscsi + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// BHS 
(Basic Header Segment) is 48 bytes, big-endian. +// RFC 7143, Section 12.1 +const BHSLength = 48 + +// Opcode constants — initiator opcodes (RFC 7143, Section 12.1.1) +const ( + OpNOPOut uint8 = 0x00 + OpSCSICmd uint8 = 0x01 + OpSCSITaskMgmt uint8 = 0x02 + OpLoginReq uint8 = 0x03 + OpTextReq uint8 = 0x04 + OpSCSIDataOut uint8 = 0x05 + OpLogoutReq uint8 = 0x06 + OpSNACKReq uint8 = 0x0c +) + +// Target opcodes (RFC 7143, Section 12.1.1) +const ( + OpNOPIn uint8 = 0x20 + OpSCSIResp uint8 = 0x21 + OpSCSITaskResp uint8 = 0x22 + OpLoginResp uint8 = 0x23 + OpTextResp uint8 = 0x24 + OpSCSIDataIn uint8 = 0x25 + OpLogoutResp uint8 = 0x26 + OpR2T uint8 = 0x31 + OpAsyncMsg uint8 = 0x32 + OpReject uint8 = 0x3f +) + +// BHS flag masks +const ( + opcMask = 0x3f // lower 6 bits of byte 0 + FlagI = 0x40 // Immediate delivery bit (byte 0) + FlagF = 0x80 // Final bit (byte 1) + FlagR = 0x40 // Read bit (byte 1, SCSI command) + FlagW = 0x20 // Write bit (byte 1, SCSI command) + FlagC = 0x40 // Continue bit (byte 1, login) + FlagT = 0x80 // Transit bit (byte 1, login) + FlagS = 0x01 // Status bit (byte 1, Data-In) + FlagU = 0x02 // Underflow bit (byte 1, SCSI Response) + FlagO = 0x04 // Overflow bit (byte 1, SCSI Response) + FlagBiU = 0x08 // Bidi underflow (byte 1, SCSI Response) + FlagBiO = 0x10 // Bidi overflow (byte 1, SCSI Response) + FlagA = 0x40 // Acknowledge bit (byte 1, Data-Out) +) + +// Login stage constants (CSG/NSG, 2-bit fields in byte 1 of login PDU) +const ( + StageSecurityNeg uint8 = 0 // Security Negotiation + StageLoginOp uint8 = 1 // Login Operational Negotiation + StageFullFeature uint8 = 3 // Full Feature Phase +) + +// SCSI status codes +const ( + SCSIStatusGood uint8 = 0x00 + SCSIStatusCheckCond uint8 = 0x02 + SCSIStatusBusy uint8 = 0x08 + SCSIStatusResvConflict uint8 = 0x18 +) + +// iSCSI response codes +const ( + ISCSIRespCompleted uint8 = 0x00 +) + +// MaxDataSegmentLength limits the maximum data segment we'll accept. 
+// The iSCSI data segment length is a 3-byte field (max 16MB-1). +// We cap at 8MB which is well above typical MaxRecvDataSegmentLength values. +const MaxDataSegmentLength = 8 * 1024 * 1024 // 8 MB + +var ( + ErrPDUTruncated = errors.New("iscsi: PDU truncated") + ErrPDUTooLarge = errors.New("iscsi: data segment exceeds maximum length") + ErrInvalidAHSLength = errors.New("iscsi: invalid AHS length (not multiple of 4)") + ErrUnknownOpcode = errors.New("iscsi: unknown opcode") +) + +// PDU represents a full iSCSI Protocol Data Unit. +type PDU struct { + BHS [BHSLength]byte + AHS []byte // Additional Header Segment (multiple of 4 bytes) + DataSegment []byte // Data segment (padded to 4-byte boundary on wire) +} + +// --- BHS field accessors --- + +// Opcode returns the opcode (lower 6 bits of byte 0). +func (p *PDU) Opcode() uint8 { return p.BHS[0] & opcMask } + +// SetOpcode sets the opcode (lower 6 bits of byte 0). +func (p *PDU) SetOpcode(op uint8) { + p.BHS[0] = (p.BHS[0] & ^uint8(opcMask)) | (op & opcMask) +} + +// Immediate returns true if the immediate delivery bit is set. +func (p *PDU) Immediate() bool { return p.BHS[0]&FlagI != 0 } + +// SetImmediate sets the immediate delivery bit. +func (p *PDU) SetImmediate(v bool) { + if v { + p.BHS[0] |= FlagI + } else { + p.BHS[0] &^= FlagI + } +} + +// OpSpecific1 returns byte 1 (opcode-specific flags). +func (p *PDU) OpSpecific1() uint8 { return p.BHS[1] } + +// SetOpSpecific1 sets byte 1. +func (p *PDU) SetOpSpecific1(v uint8) { p.BHS[1] = v } + +// TotalAHSLength returns the total AHS length in 4-byte words (byte 4). +func (p *PDU) TotalAHSLength() uint8 { return p.BHS[4] } + +// DataSegmentLength returns the 3-byte data segment length (bytes 5-7). +func (p *PDU) DataSegmentLength() uint32 { + return uint32(p.BHS[5])<<16 | uint32(p.BHS[6])<<8 | uint32(p.BHS[7]) +} + +// SetDataSegmentLength sets the 3-byte data segment length field. 
+func (p *PDU) SetDataSegmentLength(n uint32) { + p.BHS[5] = byte(n >> 16) + p.BHS[6] = byte(n >> 8) + p.BHS[7] = byte(n) +} + +// LUN returns the 8-byte LUN field (bytes 8-15). +func (p *PDU) LUN() uint64 { return binary.BigEndian.Uint64(p.BHS[8:16]) } + +// SetLUN sets the LUN field. +func (p *PDU) SetLUN(lun uint64) { binary.BigEndian.PutUint64(p.BHS[8:16], lun) } + +// InitiatorTaskTag returns bytes 16-19. +func (p *PDU) InitiatorTaskTag() uint32 { return binary.BigEndian.Uint32(p.BHS[16:20]) } + +// SetInitiatorTaskTag sets bytes 16-19. +func (p *PDU) SetInitiatorTaskTag(tag uint32) { binary.BigEndian.PutUint32(p.BHS[16:20], tag) } + +// Field32 reads a generic 4-byte field at the given BHS offset. +func (p *PDU) Field32(offset int) uint32 { return binary.BigEndian.Uint32(p.BHS[offset : offset+4]) } + +// SetField32 writes a generic 4-byte field at the given BHS offset. +func (p *PDU) SetField32(offset int, v uint32) { + binary.BigEndian.PutUint32(p.BHS[offset:offset+4], v) +} + +// --- Common BHS offsets for named fields --- + +// TSIH returns the Target Session Identifying Handle (bytes 14-15, login PDU). +func (p *PDU) TSIH() uint16 { return binary.BigEndian.Uint16(p.BHS[14:16]) } + +// SetTSIH sets the TSIH field. +func (p *PDU) SetTSIH(v uint16) { binary.BigEndian.PutUint16(p.BHS[14:16], v) } + +// CmdSN returns bytes 24-27. +func (p *PDU) CmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) } + +// SetCmdSN sets bytes 24-27. +func (p *PDU) SetCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) } + +// ExpStatSN returns bytes 28-31. +func (p *PDU) ExpStatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) } + +// SetExpStatSN sets bytes 28-31. +func (p *PDU) SetExpStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) } + +// StatSN returns bytes 24-27 (target response PDUs). +func (p *PDU) StatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) } + +// SetStatSN sets bytes 24-27. 
+func (p *PDU) SetStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) } + +// ExpCmdSN returns bytes 28-31 (target response PDUs). +func (p *PDU) ExpCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) } + +// SetExpCmdSN sets bytes 28-31. +func (p *PDU) SetExpCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) } + +// MaxCmdSN returns bytes 32-35 (target response PDUs). +func (p *PDU) MaxCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[32:36]) } + +// SetMaxCmdSN sets bytes 32-35. +func (p *PDU) SetMaxCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[32:36], v) } + +// DataSN returns bytes 36-39 (Data-In / Data-Out PDUs). +func (p *PDU) DataSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) } + +// SetDataSN sets bytes 36-39. +func (p *PDU) SetDataSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) } + +// BufferOffset returns bytes 40-43 (Data-In / Data-Out PDUs). +func (p *PDU) BufferOffset() uint32 { return binary.BigEndian.Uint32(p.BHS[40:44]) } + +// SetBufferOffset sets bytes 40-43. +func (p *PDU) SetBufferOffset(v uint32) { binary.BigEndian.PutUint32(p.BHS[40:44], v) } + +// R2TSN returns bytes 36-39 (R2T PDUs). +func (p *PDU) R2TSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) } + +// SetR2TSN sets bytes 36-39. +func (p *PDU) SetR2TSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) } + +// DesiredDataLength returns bytes 44-47 (R2T PDUs). +func (p *PDU) DesiredDataLength() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) } + +// SetDesiredDataLength sets bytes 44-47. +func (p *PDU) SetDesiredDataLength(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) } + +// ExpectedDataTransferLength returns bytes 20-23 (SCSI Command PDUs). +func (p *PDU) ExpectedDataTransferLength() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } + +// SetExpectedDataTransferLength sets bytes 20-23. 
+func (p *PDU) SetExpectedDataTransferLength(v uint32) { + binary.BigEndian.PutUint32(p.BHS[20:24], v) +} + +// TargetTransferTag returns bytes 20-23 (target PDUs: R2T, Data-In, etc.). +func (p *PDU) TargetTransferTag() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } + +// SetTargetTransferTag sets bytes 20-23. +func (p *PDU) SetTargetTransferTag(v uint32) { binary.BigEndian.PutUint32(p.BHS[20:24], v) } + +// ISID returns the 6-byte Initiator Session ID (bytes 8-13, login PDU). +func (p *PDU) ISID() [6]byte { + var id [6]byte + copy(id[:], p.BHS[8:14]) + return id +} + +// SetISID sets the 6-byte ISID field. +func (p *PDU) SetISID(id [6]byte) { copy(p.BHS[8:14], id[:]) } + +// CDB returns the 16-byte SCSI CDB from the BHS (bytes 32-47, SCSI Command PDU). +func (p *PDU) CDB() [16]byte { + var cdb [16]byte + copy(cdb[:], p.BHS[32:48]) + return cdb +} + +// SetCDB sets the 16-byte SCSI CDB in the BHS. +func (p *PDU) SetCDB(cdb [16]byte) { copy(p.BHS[32:48], cdb[:]) } + +// ResidualCount returns bytes 44-47 (SCSI Response PDUs). +func (p *PDU) ResidualCount() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) } + +// SetResidualCount sets bytes 44-47. +func (p *PDU) SetResidualCount(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) } + +// --- Wire I/O --- + +// pad4 rounds n up to the next multiple of 4. +func pad4(n uint32) uint32 { return (n + 3) &^ 3 } + +// ReadPDU reads a complete PDU from r. 
+func ReadPDU(r io.Reader) (*PDU, error) { + p := &PDU{} + + // Read BHS (48 bytes) + if _, err := io.ReadFull(r, p.BHS[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, ErrPDUTruncated + } + return nil, err + } + + // AHS + ahsLen := uint32(p.TotalAHSLength()) * 4 + if ahsLen > 0 { + p.AHS = make([]byte, ahsLen) + if _, err := io.ReadFull(r, p.AHS); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, ErrPDUTruncated + } + return nil, err + } + } + + // Data segment + dsLen := p.DataSegmentLength() + if dsLen > MaxDataSegmentLength { + return nil, fmt.Errorf("%w: %d bytes", ErrPDUTooLarge, dsLen) + } + + if dsLen > 0 { + paddedLen := pad4(dsLen) + buf := make([]byte, paddedLen) + if _, err := io.ReadFull(r, buf); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, ErrPDUTruncated + } + return nil, err + } + p.DataSegment = buf[:dsLen] // strip padding + } + + return p, nil +} + +// WritePDU writes a complete PDU to w, with proper padding. +func WritePDU(w io.Writer, p *PDU) error { + // Update header lengths from actual AHS/DataSegment + if len(p.AHS) > 0 { + if len(p.AHS)%4 != 0 { + return ErrInvalidAHSLength + } + p.BHS[4] = uint8(len(p.AHS) / 4) + } else { + p.BHS[4] = 0 + } + p.SetDataSegmentLength(uint32(len(p.DataSegment))) + + // Write BHS + if _, err := w.Write(p.BHS[:]); err != nil { + return err + } + + // Write AHS + if len(p.AHS) > 0 { + if _, err := w.Write(p.AHS); err != nil { + return err + } + } + + // Write data segment with padding + if len(p.DataSegment) > 0 { + if _, err := w.Write(p.DataSegment); err != nil { + return err + } + padLen := pad4(uint32(len(p.DataSegment))) - uint32(len(p.DataSegment)) + if padLen > 0 { + var pad [3]byte + if _, err := w.Write(pad[:padLen]); err != nil { + return err + } + } + } + + return nil +} + +// OpcodeName returns a human-readable name for the given opcode. 
+func OpcodeName(op uint8) string { + switch op { + case OpNOPOut: + return "NOP-Out" + case OpSCSICmd: + return "SCSI-Command" + case OpSCSITaskMgmt: + return "SCSI-Task-Mgmt" + case OpLoginReq: + return "Login-Request" + case OpTextReq: + return "Text-Request" + case OpSCSIDataOut: + return "SCSI-Data-Out" + case OpLogoutReq: + return "Logout-Request" + case OpSNACKReq: + return "SNACK-Request" + case OpNOPIn: + return "NOP-In" + case OpSCSIResp: + return "SCSI-Response" + case OpSCSITaskResp: + return "SCSI-Task-Mgmt-Response" + case OpLoginResp: + return "Login-Response" + case OpTextResp: + return "Text-Response" + case OpSCSIDataIn: + return "SCSI-Data-In" + case OpLogoutResp: + return "Logout-Response" + case OpR2T: + return "R2T" + case OpAsyncMsg: + return "Async-Message" + case OpReject: + return "Reject" + default: + return fmt.Sprintf("Unknown(0x%02x)", op) + } +} + +// Login PDU helpers + +// LoginCSG returns the Current Stage from byte 1 of a login PDU (bits 3-2). +func (p *PDU) LoginCSG() uint8 { return (p.BHS[1] >> 2) & 0x03 } + +// LoginNSG returns the Next Stage from byte 1 of a login PDU (bits 1-0). +func (p *PDU) LoginNSG() uint8 { return p.BHS[1] & 0x03 } + +// SetLoginStages sets the CSG and NSG fields in byte 1, preserving T and C flags. +func (p *PDU) SetLoginStages(csg, nsg uint8) { + p.BHS[1] = (p.BHS[1] & 0xF0) | ((csg & 0x03) << 2) | (nsg & 0x03) +} + +// LoginTransit returns true if the Transit bit is set (byte 1, bit 7). +func (p *PDU) LoginTransit() bool { return p.BHS[1]&FlagT != 0 } + +// SetLoginTransit sets or clears the Transit bit. +func (p *PDU) SetLoginTransit(v bool) { + if v { + p.BHS[1] |= FlagT + } else { + p.BHS[1] &^= FlagT + } +} + +// LoginContinue returns true if the Continue bit is set. +func (p *PDU) LoginContinue() bool { return p.BHS[1]&FlagC != 0 } + +// LoginStatusClass returns the status class (byte 36) of a login response. 
+func (p *PDU) LoginStatusClass() uint8 { return p.BHS[36] } + +// LoginStatusDetail returns the status detail (byte 37) of a login response. +func (p *PDU) LoginStatusDetail() uint8 { return p.BHS[37] } + +// SetLoginStatus sets the status class and detail in a login response. +func (p *PDU) SetLoginStatus(class, detail uint8) { + p.BHS[36] = class + p.BHS[37] = detail +} + +// SCSIResponse returns the iSCSI response byte (byte 2) of a SCSI Response PDU. +func (p *PDU) SCSIResponse() uint8 { return p.BHS[2] } + +// SetSCSIResponse sets byte 2. +func (p *PDU) SetSCSIResponse(v uint8) { p.BHS[2] = v } + +// SCSIStatus returns the SCSI status byte (byte 3) of a SCSI Response PDU. +func (p *PDU) SCSIStatus() uint8 { return p.BHS[3] } + +// SetSCSIStatus sets byte 3. +func (p *PDU) SetSCSIStatus(v uint8) { p.BHS[3] = v } diff --git a/weed/storage/blockvol/iscsi/pdu_test.go b/weed/storage/blockvol/iscsi/pdu_test.go new file mode 100644 index 000000000..735f14339 --- /dev/null +++ b/weed/storage/blockvol/iscsi/pdu_test.go @@ -0,0 +1,559 @@ +package iscsi + +import ( + "bytes" + "encoding/binary" + "io" + "strings" + "testing" +) + +func TestPDU(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"roundtrip_bhs_only", testRoundtripBHSOnly}, + {"roundtrip_with_data", testRoundtripWithData}, + {"roundtrip_with_ahs_and_data", testRoundtripWithAHSAndData}, + {"data_padding", testDataPadding}, + {"opcode_accessors", testOpcodeAccessors}, + {"immediate_flag", testImmediateFlag}, + {"field32_accessors", testField32Accessors}, + {"lun_accessor", testLUNAccessor}, + {"isid_accessor", testISIDAccessor}, + {"cdb_accessor", testCDBAccessor}, + {"login_stage_accessors", testLoginStageAccessors}, + {"login_transit_continue", testLoginTransitContinue}, + {"login_status", testLoginStatus}, + {"scsi_response_status", testSCSIResponseStatus}, + {"data_segment_length_3byte", testDataSegmentLength3Byte}, + {"tsih_accessor", testTSIHAccessor}, + 
{"cmdsn_expstatsn", testCmdSNExpStatSN}, + {"statsn_expcmdsn_maxcmdsn", testStatSNExpCmdSNMaxCmdSN}, + {"datasn_bufferoffset", testDataSNBufferOffset}, + {"r2t_fields", testR2TFields}, + {"residual_count", testResidualCount}, + {"expected_data_transfer_length", testExpectedDataTransferLength}, + {"target_transfer_tag", testTargetTransferTag}, + {"opcode_name", testOpcodeName}, + {"read_truncated_bhs", testReadTruncatedBHS}, + {"read_truncated_data", testReadTruncatedData}, + {"read_truncated_ahs", testReadTruncatedAHS}, + {"read_oversized_data", testReadOversizedData}, + {"read_eof", testReadEOF}, + {"write_invalid_ahs_length", testWriteInvalidAHSLength}, + {"roundtrip_all_opcodes", testRoundtripAllOpcodes}, + {"data_segment_exact_4byte_boundary", testDataSegmentExact4ByteBoundary}, + {"zero_length_data_segment", testZeroLengthDataSegment}, + {"max_3byte_data_length", testMax3ByteDataLength}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testRoundtripBHSOnly(t *testing.T) { + p := &PDU{} + p.SetOpcode(OpSCSICmd) + p.SetInitiatorTaskTag(0xDEADBEEF) + p.SetCmdSN(42) + + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatal(err) + } + if buf.Len() != BHSLength { + t.Fatalf("expected %d bytes, got %d", BHSLength, buf.Len()) + } + + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatal(err) + } + if p2.Opcode() != OpSCSICmd { + t.Fatalf("opcode mismatch: got 0x%02x", p2.Opcode()) + } + if p2.InitiatorTaskTag() != 0xDEADBEEF { + t.Fatalf("ITT mismatch: got 0x%x", p2.InitiatorTaskTag()) + } + if p2.CmdSN() != 42 { + t.Fatalf("CmdSN mismatch: got %d", p2.CmdSN()) + } +} + +func testRoundtripWithData(t *testing.T) { + p := &PDU{} + p.SetOpcode(OpSCSIDataIn) + p.DataSegment = []byte("hello, iSCSI world!") + + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatal(err) + } + + // BHS(48) + data(19) + padding(1) = 68 + expectedLen := 48 + pad4(uint32(len(p.DataSegment))) + if 
uint32(buf.Len()) != expectedLen { + t.Fatalf("wire length: expected %d, got %d", expectedLen, buf.Len()) + } + + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(p2.DataSegment, []byte("hello, iSCSI world!")) { + t.Fatalf("data mismatch: %q", p2.DataSegment) + } +} + +func testRoundtripWithAHSAndData(t *testing.T) { + p := &PDU{} + p.SetOpcode(OpSCSICmd) + p.AHS = make([]byte, 8) // 2 words + p.AHS[0] = 0xAA + p.AHS[7] = 0xBB + p.DataSegment = []byte{1, 2, 3, 4, 5} + + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatal(err) + } + + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatal(err) + } + if len(p2.AHS) != 8 { + t.Fatalf("AHS length: expected 8, got %d", len(p2.AHS)) + } + if p2.AHS[0] != 0xAA || p2.AHS[7] != 0xBB { + t.Fatal("AHS content mismatch") + } + if !bytes.Equal(p2.DataSegment, []byte{1, 2, 3, 4, 5}) { + t.Fatal("data segment mismatch") + } +} + +func testDataPadding(t *testing.T) { + for _, dataLen := range []int{1, 2, 3, 4, 5, 7, 8, 100, 1023} { + p := &PDU{} + p.SetOpcode(OpSCSIDataIn) + p.DataSegment = bytes.Repeat([]byte{0xFF}, dataLen) + + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatalf("dataLen=%d: write: %v", dataLen, err) + } + + expectedWire := BHSLength + int(pad4(uint32(dataLen))) + if buf.Len() != expectedWire { + t.Fatalf("dataLen=%d: wire=%d, expected=%d", dataLen, buf.Len(), expectedWire) + } + + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatalf("dataLen=%d: read: %v", dataLen, err) + } + if len(p2.DataSegment) != dataLen { + t.Fatalf("dataLen=%d: got %d", dataLen, len(p2.DataSegment)) + } + } +} + +func testOpcodeAccessors(t *testing.T) { + p := &PDU{} + // Set immediate first, then opcode — verify no interference + p.SetImmediate(true) + p.SetOpcode(OpLoginReq) + if p.Opcode() != OpLoginReq { + t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpLoginReq) + } + if !p.Immediate() { + t.Fatal("immediate flag lost") + } + + // Change opcode, 
verify immediate preserved + p.SetOpcode(OpTextReq) + if p.Opcode() != OpTextReq { + t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpTextReq) + } + if !p.Immediate() { + t.Fatal("immediate flag lost after opcode change") + } +} + +func testImmediateFlag(t *testing.T) { + p := &PDU{} + if p.Immediate() { + t.Fatal("should start false") + } + p.SetImmediate(true) + if !p.Immediate() { + t.Fatal("should be true") + } + p.SetImmediate(false) + if p.Immediate() { + t.Fatal("should be false again") + } +} + +func testField32Accessors(t *testing.T) { + p := &PDU{} + p.SetField32(20, 0x12345678) + if got := p.Field32(20); got != 0x12345678 { + t.Fatalf("got 0x%08x", got) + } +} + +func testLUNAccessor(t *testing.T) { + p := &PDU{} + p.SetLUN(0x0001000000000000) // LUN 1 in SAM encoding + if p.LUN() != 0x0001000000000000 { + t.Fatalf("LUN mismatch: 0x%016x", p.LUN()) + } +} + +func testISIDAccessor(t *testing.T) { + p := &PDU{} + isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} + p.SetISID(isid) + got := p.ISID() + if got != isid { + t.Fatalf("ISID mismatch: %v", got) + } +} + +func testCDBAccessor(t *testing.T) { + p := &PDU{} + var cdb [16]byte + cdb[0] = 0x28 // READ_10 + cdb[1] = 0x00 + binary.BigEndian.PutUint32(cdb[2:6], 100) // LBA + binary.BigEndian.PutUint16(cdb[7:9], 8) // transfer length + p.SetCDB(cdb) + got := p.CDB() + if got != cdb { + t.Fatal("CDB mismatch") + } +} + +func testLoginStageAccessors(t *testing.T) { + p := &PDU{} + p.SetLoginStages(StageSecurityNeg, StageLoginOp) + if p.LoginCSG() != StageSecurityNeg { + t.Fatalf("CSG: got %d", p.LoginCSG()) + } + if p.LoginNSG() != StageLoginOp { + t.Fatalf("NSG: got %d", p.LoginNSG()) + } + + // Change to LoginOp -> FullFeature + p.SetLoginStages(StageLoginOp, StageFullFeature) + if p.LoginCSG() != StageLoginOp { + t.Fatalf("CSG: got %d", p.LoginCSG()) + } + if p.LoginNSG() != StageFullFeature { + t.Fatalf("NSG: got %d", p.LoginNSG()) + } +} + +func testLoginTransitContinue(t *testing.T) { + p := 
&PDU{} + if p.LoginTransit() { + t.Fatal("Transit should start false") + } + if p.LoginContinue() { + t.Fatal("Continue should start false") + } + + p.SetLoginTransit(true) + p.SetLoginStages(StageLoginOp, StageFullFeature) + if !p.LoginTransit() { + t.Fatal("Transit should be true") + } + // Verify stages preserved + if p.LoginCSG() != StageLoginOp { + t.Fatal("CSG lost after setting transit") + } +} + +func testLoginStatus(t *testing.T) { + p := &PDU{} + p.SetLoginStatus(0x02, 0x01) // Initiator error, authentication failure + if p.LoginStatusClass() != 0x02 { + t.Fatalf("class: got %d", p.LoginStatusClass()) + } + if p.LoginStatusDetail() != 0x01 { + t.Fatalf("detail: got %d", p.LoginStatusDetail()) + } +} + +func testSCSIResponseStatus(t *testing.T) { + p := &PDU{} + p.SetSCSIResponse(ISCSIRespCompleted) + p.SetSCSIStatus(SCSIStatusGood) + if p.SCSIResponse() != ISCSIRespCompleted { + t.Fatal("response mismatch") + } + if p.SCSIStatus() != SCSIStatusGood { + t.Fatal("status mismatch") + } +} + +func testDataSegmentLength3Byte(t *testing.T) { + p := &PDU{} + // Test various sizes including >64KB (needs all 3 bytes) + for _, size := range []uint32{0, 1, 255, 256, 65535, 65536, 1<<24 - 1} { + p.SetDataSegmentLength(size) + if got := p.DataSegmentLength(); got != size { + t.Fatalf("size %d: got %d", size, got) + } + } +} + +func testTSIHAccessor(t *testing.T) { + p := &PDU{} + p.SetTSIH(0x1234) + if p.TSIH() != 0x1234 { + t.Fatalf("TSIH: got 0x%04x", p.TSIH()) + } +} + +func testCmdSNExpStatSN(t *testing.T) { + p := &PDU{} + p.SetCmdSN(100) + p.SetExpStatSN(200) + if p.CmdSN() != 100 { + t.Fatal("CmdSN mismatch") + } + if p.ExpStatSN() != 200 { + t.Fatal("ExpStatSN mismatch") + } +} + +func testStatSNExpCmdSNMaxCmdSN(t *testing.T) { + p := &PDU{} + p.SetStatSN(10) + p.SetExpCmdSN(20) + p.SetMaxCmdSN(30) + if p.StatSN() != 10 { + t.Fatal("StatSN mismatch") + } + if p.ExpCmdSN() != 20 { + t.Fatal("ExpCmdSN mismatch") + } + if p.MaxCmdSN() != 30 { + t.Fatal("MaxCmdSN 
mismatch") + } +} + +func testDataSNBufferOffset(t *testing.T) { + p := &PDU{} + p.SetDataSN(5) + p.SetBufferOffset(8192) + if p.DataSN() != 5 { + t.Fatal("DataSN mismatch") + } + if p.BufferOffset() != 8192 { + t.Fatal("BufferOffset mismatch") + } +} + +func testR2TFields(t *testing.T) { + p := &PDU{} + p.SetR2TSN(3) + p.SetDesiredDataLength(65536) + if p.R2TSN() != 3 { + t.Fatal("R2TSN mismatch") + } + if p.DesiredDataLength() != 65536 { + t.Fatal("DesiredDataLength mismatch") + } +} + +func testResidualCount(t *testing.T) { + p := &PDU{} + p.SetResidualCount(512) + if p.ResidualCount() != 512 { + t.Fatal("ResidualCount mismatch") + } +} + +func testExpectedDataTransferLength(t *testing.T) { + p := &PDU{} + p.SetExpectedDataTransferLength(1048576) + if p.ExpectedDataTransferLength() != 1048576 { + t.Fatal("mismatch") + } +} + +func testTargetTransferTag(t *testing.T) { + p := &PDU{} + p.SetTargetTransferTag(0xFFFFFFFF) + if p.TargetTransferTag() != 0xFFFFFFFF { + t.Fatal("mismatch") + } +} + +func testOpcodeName(t *testing.T) { + if OpcodeName(OpSCSICmd) != "SCSI-Command" { + t.Fatal("wrong name for SCSI-Command") + } + if OpcodeName(OpLoginReq) != "Login-Request" { + t.Fatal("wrong name for Login-Request") + } + if !strings.HasPrefix(OpcodeName(0xFF), "Unknown") { + t.Fatal("unknown opcode should have Unknown prefix") + } +} + +func testReadTruncatedBHS(t *testing.T) { + // Only 20 bytes — not enough for BHS + buf := bytes.NewReader(make([]byte, 20)) + _, err := ReadPDU(buf) + if err == nil { + t.Fatal("expected error") + } +} + +func testReadTruncatedData(t *testing.T) { + // Valid BHS claiming 100 bytes of data, but only 10 bytes follow + var bhs [BHSLength]byte + bhs[5] = 0 + bhs[6] = 0 + bhs[7] = 100 // DataSegmentLength = 100 + buf := make([]byte, BHSLength+10) + copy(buf, bhs[:]) + + _, err := ReadPDU(bytes.NewReader(buf)) + if err == nil { + t.Fatal("expected truncation error") + } +} + +func testReadTruncatedAHS(t *testing.T) { + // BHS claiming 2 words 
of AHS but only 4 bytes follow + var bhs [BHSLength]byte + bhs[4] = 2 // TotalAHSLength = 2 words = 8 bytes + buf := make([]byte, BHSLength+4) + copy(buf, bhs[:]) + + _, err := ReadPDU(bytes.NewReader(buf)) + if err == nil { + t.Fatal("expected truncation error") + } +} + +func testReadOversizedData(t *testing.T) { + // BHS claiming MaxDataSegmentLength+1 bytes + var bhs [BHSLength]byte + oversized := uint32(MaxDataSegmentLength + 1) + bhs[5] = byte(oversized >> 16) + bhs[6] = byte(oversized >> 8) + bhs[7] = byte(oversized) + + _, err := ReadPDU(bytes.NewReader(bhs[:])) + if err == nil { + t.Fatal("expected oversized error") + } + if !strings.Contains(err.Error(), "exceeds maximum") { + t.Fatalf("unexpected error: %v", err) + } +} + +func testReadEOF(t *testing.T) { + _, err := ReadPDU(bytes.NewReader(nil)) + if err != io.EOF { + t.Fatalf("expected io.EOF, got %v", err) + } +} + +func testWriteInvalidAHSLength(t *testing.T) { + p := &PDU{} + p.AHS = make([]byte, 5) // not multiple of 4 + var buf bytes.Buffer + err := WritePDU(&buf, p) + if err != ErrInvalidAHSLength { + t.Fatalf("expected ErrInvalidAHSLength, got %v", err) + } +} + +func testRoundtripAllOpcodes(t *testing.T) { + opcodes := []uint8{ + OpNOPOut, OpSCSICmd, OpSCSITaskMgmt, OpLoginReq, OpTextReq, + OpSCSIDataOut, OpLogoutReq, OpSNACKReq, + OpNOPIn, OpSCSIResp, OpSCSITaskResp, OpLoginResp, OpTextResp, + OpSCSIDataIn, OpLogoutResp, OpR2T, OpAsyncMsg, OpReject, + } + for _, op := range opcodes { + p := &PDU{} + p.SetOpcode(op) + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatalf("op=0x%02x: write: %v", op, err) + } + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatalf("op=0x%02x: read: %v", op, err) + } + if p2.Opcode() != op { + t.Fatalf("op=0x%02x: got 0x%02x", op, p2.Opcode()) + } + } +} + +func testDataSegmentExact4ByteBoundary(t *testing.T) { + for _, size := range []int{4, 8, 12, 16, 256, 4096} { + p := &PDU{} + p.DataSegment = bytes.Repeat([]byte{0xAB}, size) + var buf 
bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatalf("size=%d: %v", size, err) + } + // No padding needed + if buf.Len() != BHSLength+size { + t.Fatalf("size=%d: wire=%d, expected=%d", size, buf.Len(), BHSLength+size) + } + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatalf("size=%d: read: %v", size, err) + } + if len(p2.DataSegment) != size { + t.Fatalf("size=%d: got %d", size, len(p2.DataSegment)) + } + } +} + +func testZeroLengthDataSegment(t *testing.T) { + p := &PDU{} + p.SetOpcode(OpNOPOut) + // DataSegment is nil + + var buf bytes.Buffer + if err := WritePDU(&buf, p); err != nil { + t.Fatal(err) + } + if buf.Len() != BHSLength { + t.Fatalf("expected %d, got %d", BHSLength, buf.Len()) + } + + p2, err := ReadPDU(&buf) + if err != nil { + t.Fatal(err) + } + if len(p2.DataSegment) != 0 { + t.Fatalf("expected empty data segment, got %d bytes", len(p2.DataSegment)) + } +} + +func testMax3ByteDataLength(t *testing.T) { + p := &PDU{} + maxLen := uint32(1<<24 - 1) // 16MB - 1 + p.SetDataSegmentLength(maxLen) + if got := p.DataSegmentLength(); got != maxLen { + t.Fatalf("expected %d, got %d", maxLen, got) + } +} diff --git a/weed/storage/blockvol/iscsi/scsi.go b/weed/storage/blockvol/iscsi/scsi.go new file mode 100644 index 000000000..ef3f87a7c --- /dev/null +++ b/weed/storage/blockvol/iscsi/scsi.go @@ -0,0 +1,463 @@ +package iscsi + +import ( + "encoding/binary" +) + +// SCSI opcode constants (SPC-5 / SBC-4) +const ( + ScsiTestUnitReady uint8 = 0x00 + ScsiInquiry uint8 = 0x12 + ScsiModeSense6 uint8 = 0x1a + ScsiReadCapacity10 uint8 = 0x25 + ScsiRead10 uint8 = 0x28 + ScsiWrite10 uint8 = 0x2a + ScsiSyncCache10 uint8 = 0x35 + ScsiUnmap uint8 = 0x42 + ScsiReportLuns uint8 = 0xa0 + ScsiRead16 uint8 = 0x88 + ScsiWrite16 uint8 = 0x8a + ScsiReadCapacity16 uint8 = 0x9e // SERVICE ACTION IN (16), SA=0x10 + ScsiSyncCache16 uint8 = 0x91 +) + +// Service action for READ CAPACITY (16) +const ScsiSAReadCapacity16 uint8 = 0x10 + +// SCSI sense keys +const ( + 
SenseNoSense uint8 = 0x00
+	SenseNotReady uint8 = 0x02
+	SenseMediumError uint8 = 0x03
+	SenseHardwareError uint8 = 0x04
+	SenseIllegalRequest uint8 = 0x05
+	SenseAbortedCommand uint8 = 0x0b
+)
+
+// ASC/ASCQ pairs
+const (
+	ASCInvalidOpcode uint8 = 0x20
+	ASCQLuk uint8 = 0x00 // generic "no additional qualifier" ASCQ
+	ASCInvalidFieldInCDB uint8 = 0x24
+	ASCLBAOutOfRange uint8 = 0x21
+	ASCNotReady uint8 = 0x04
+	ASCQNotReady uint8 = 0x03 // manual intervention required
+)
+
+// BlockDevice is the interface that the SCSI command handler uses to
+// interact with the underlying storage. This maps onto BlockVol.
+type BlockDevice interface {
+	ReadAt(lba uint64, length uint32) ([]byte, error) // length is in bytes, a multiple of BlockSize
+	WriteAt(lba uint64, data []byte) error // len(data) is a multiple of BlockSize
+	Trim(lba uint64, length uint32) error // length is in bytes
+	SyncCache() error
+	BlockSize() uint32
+	VolumeSize() uint64 // total size in bytes
+	IsHealthy() bool
+}
+
+// SCSIHandler processes SCSI commands from iSCSI PDUs.
+type SCSIHandler struct {
+	dev BlockDevice
+	vendorID string // 8 bytes for INQUIRY
+	prodID string // 16 bytes for INQUIRY
+	serial string // for VPD page 0x80
+}
+
+// NewSCSIHandler creates a SCSI command handler for the given block device.
+func NewSCSIHandler(dev BlockDevice) *SCSIHandler {
+	return &SCSIHandler{
+		dev: dev,
+		vendorID: "SeaweedF",
+		prodID: "BlockVol ",
+		serial: "SWF00001",
+	}
+}
+
+// SCSIResult holds the result of a SCSI command execution.
+type SCSIResult struct {
+	Status uint8 // SCSI status
+	Data []byte // Response data (for Data-In)
+	SenseKey uint8 // Sense key (if CHECK_CONDITION)
+	SenseASC uint8 // Additional sense code
+	SenseASCQ uint8 // Additional sense code qualifier
+}
+
+// HandleCommand dispatches a SCSI CDB to the appropriate handler.
+// dataOut contains any data sent by the initiator (for WRITE and UNMAP
+// commands); it may be nil for commands with no Data-Out phase.
+func (h *SCSIHandler) HandleCommand(cdb [16]byte, dataOut []byte) SCSIResult { + opcode := cdb[0] + + switch opcode { + case ScsiTestUnitReady: + return h.testUnitReady() + case ScsiInquiry: + return h.inquiry(cdb) + case ScsiModeSense6: + return h.modeSense6(cdb) + case ScsiReadCapacity10: + return h.readCapacity10() + case ScsiReadCapacity16: + sa := cdb[1] & 0x1f + if sa == ScsiSAReadCapacity16 { + return h.readCapacity16(cdb) + } + return illegalRequest(ASCInvalidOpcode, ASCQLuk) + case ScsiReportLuns: + return h.reportLuns(cdb) + case ScsiRead10: + return h.read10(cdb) + case ScsiRead16: + return h.read16(cdb) + case ScsiWrite10: + return h.write10(cdb, dataOut) + case ScsiWrite16: + return h.write16(cdb, dataOut) + case ScsiSyncCache10: + return h.syncCache() + case ScsiSyncCache16: + return h.syncCache() + case ScsiUnmap: + return h.unmap(cdb, dataOut) + default: + return illegalRequest(ASCInvalidOpcode, ASCQLuk) + } +} + +// --- Metadata commands --- + +func (h *SCSIHandler) testUnitReady() SCSIResult { + if !h.dev.IsHealthy() { + return SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseNotReady, + SenseASC: ASCNotReady, + SenseASCQ: ASCQNotReady, + } + } + return SCSIResult{Status: SCSIStatusGood} +} + +func (h *SCSIHandler) inquiry(cdb [16]byte) SCSIResult { + evpd := cdb[1] & 0x01 + pageCode := cdb[2] + allocLen := binary.BigEndian.Uint16(cdb[3:5]) + if allocLen == 0 { + allocLen = 36 + } + + if evpd != 0 { + return h.inquiryVPD(pageCode, allocLen) + } + + // Standard INQUIRY response (SPC-5, Section 6.6.1) + data := make([]byte, 96) + data[0] = 0x00 // Peripheral device type: SBC (direct access block device) + data[1] = 0x00 // RMB=0 (not removable) + data[2] = 0x06 // SPC-4 version + data[3] = 0x02 // Response data format = 2 (SPC-2+) + data[4] = 91 // Additional length (96-5) + data[5] = 0x00 // SCCS, ACC, TPGS, 3PC + data[6] = 0x00 // Obsolete, EncServ, VS, MultiP + data[7] = 0x02 // CmdQue=1 (supports command queuing) + + // Vendor ID 
(bytes 8-15, 8 chars, space padded) + copy(data[8:16], padRight(h.vendorID, 8)) + // Product ID (bytes 16-31, 16 chars, space padded) + copy(data[16:32], padRight(h.prodID, 16)) + // Product revision (bytes 32-35, 4 chars) + copy(data[32:36], "0001") + + if int(allocLen) < len(data) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +func (h *SCSIHandler) inquiryVPD(pageCode uint8, allocLen uint16) SCSIResult { + switch pageCode { + case 0x00: // Supported VPD pages + data := []byte{ + 0x00, // device type + 0x00, // page code + 0x00, 0x03, // page length + 0x00, // supported pages: 0x00 + 0x80, // 0x80 (serial) + 0x83, // 0x83 (device identification) + } + if int(allocLen) < len(data) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} + + case 0x80: // Unit serial number + serial := padRight(h.serial, 8) + data := make([]byte, 4+len(serial)) + data[0] = 0x00 // device type + data[1] = 0x80 // page code + binary.BigEndian.PutUint16(data[2:4], uint16(len(serial))) + copy(data[4:], serial) + if int(allocLen) < len(data) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} + + case 0x83: // Device identification + // NAA identifier (8 bytes) + naaID := []byte{ + 0x01, // code set: binary + 0x03, // identifier type: NAA + 0x00, // reserved + 0x08, // identifier length + 0x60, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // NAA-6 fake + } + data := make([]byte, 4+len(naaID)) + data[0] = 0x00 // device type + data[1] = 0x83 // page code + binary.BigEndian.PutUint16(data[2:4], uint16(len(naaID))) + copy(data[4:], naaID) + if int(allocLen) < len(data) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} + + default: + return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) + } +} + +func (h *SCSIHandler) readCapacity10() SCSIResult { + blockSize := h.dev.BlockSize() + totalBlocks := h.dev.VolumeSize() / uint64(blockSize) + + data := 
make([]byte, 8) + // If >2TB (blocks > 0xFFFFFFFF), return 0xFFFFFFFF to signal use READ_CAPACITY_16 + if totalBlocks > 0xFFFFFFFF { + binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) + } else { + binary.BigEndian.PutUint32(data[0:4], uint32(totalBlocks-1)) // last LBA + } + binary.BigEndian.PutUint32(data[4:8], blockSize) + + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +func (h *SCSIHandler) readCapacity16(cdb [16]byte) SCSIResult { + allocLen := binary.BigEndian.Uint32(cdb[10:14]) + if allocLen < 32 { + allocLen = 32 + } + + blockSize := h.dev.BlockSize() + totalBlocks := h.dev.VolumeSize() / uint64(blockSize) + + data := make([]byte, 32) + binary.BigEndian.PutUint64(data[0:8], totalBlocks-1) // last LBA + binary.BigEndian.PutUint32(data[8:12], blockSize) // block length + // data[12]: LBPME (logical block provisioning management enabled) = 1 for UNMAP support + data[14] = 0x80 // LBPME bit + + if allocLen < uint32(len(data)) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +func (h *SCSIHandler) modeSense6(cdb [16]byte) SCSIResult { + // Minimal MODE SENSE(6) response — no mode pages + allocLen := cdb[4] + if allocLen == 0 { + allocLen = 4 + } + + data := make([]byte, 4) + data[0] = 3 // Mode data length (3 bytes follow) + data[1] = 0x00 // Medium type: default + data[2] = 0x00 // Device-specific parameter (no write protect) + data[3] = 0x00 // Block descriptor length = 0 + + if int(allocLen) < len(data) { + data = data[:allocLen] + } + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +func (h *SCSIHandler) reportLuns(cdb [16]byte) SCSIResult { + allocLen := binary.BigEndian.Uint32(cdb[6:10]) + if allocLen < 16 { + allocLen = 16 + } + + // Report a single LUN (LUN 0) + data := make([]byte, 16) + binary.BigEndian.PutUint32(data[0:4], 8) // LUN list length (8 bytes, 1 LUN) + // data[4:7] reserved + // data[8:16] = LUN 0 (all zeros) + + if allocLen < uint32(len(data)) { + data = data[:allocLen] + 
} + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +// --- Data commands (Task 2.6) --- + +func (h *SCSIHandler) read10(cdb [16]byte) SCSIResult { + lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) + transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) + return h.doRead(lba, transferLen) +} + +func (h *SCSIHandler) read16(cdb [16]byte) SCSIResult { + lba := binary.BigEndian.Uint64(cdb[2:10]) + transferLen := binary.BigEndian.Uint32(cdb[10:14]) + return h.doRead(lba, transferLen) +} + +func (h *SCSIHandler) write10(cdb [16]byte, dataOut []byte) SCSIResult { + lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) + transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) + return h.doWrite(lba, transferLen, dataOut) +} + +func (h *SCSIHandler) write16(cdb [16]byte, dataOut []byte) SCSIResult { + lba := binary.BigEndian.Uint64(cdb[2:10]) + transferLen := binary.BigEndian.Uint32(cdb[10:14]) + return h.doWrite(lba, transferLen, dataOut) +} + +func (h *SCSIHandler) doRead(lba uint64, transferLen uint32) SCSIResult { + if transferLen == 0 { + return SCSIResult{Status: SCSIStatusGood} + } + + blockSize := h.dev.BlockSize() + totalBlocks := h.dev.VolumeSize() / uint64(blockSize) + + if lba+uint64(transferLen) > totalBlocks { + return illegalRequest(ASCLBAOutOfRange, ASCQLuk) + } + + byteLen := transferLen * blockSize + data, err := h.dev.ReadAt(lba, byteLen) + if err != nil { + return SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseMediumError, + SenseASC: 0x11, // Unrecovered read error + SenseASCQ: 0x00, + } + } + + return SCSIResult{Status: SCSIStatusGood, Data: data} +} + +func (h *SCSIHandler) doWrite(lba uint64, transferLen uint32, dataOut []byte) SCSIResult { + if transferLen == 0 { + return SCSIResult{Status: SCSIStatusGood} + } + + blockSize := h.dev.BlockSize() + totalBlocks := h.dev.VolumeSize() / uint64(blockSize) + + if lba+uint64(transferLen) > totalBlocks { + return illegalRequest(ASCLBAOutOfRange, ASCQLuk) + } + + expectedBytes := 
transferLen * blockSize + if uint32(len(dataOut)) < expectedBytes { + return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) + } + + if err := h.dev.WriteAt(lba, dataOut[:expectedBytes]); err != nil { + return SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseMediumError, + SenseASC: 0x0C, // Write error + SenseASCQ: 0x00, + } + } + + return SCSIResult{Status: SCSIStatusGood} +} + +func (h *SCSIHandler) syncCache() SCSIResult { + if err := h.dev.SyncCache(); err != nil { + return SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseHardwareError, + SenseASC: 0x00, + SenseASCQ: 0x00, + } + } + return SCSIResult{Status: SCSIStatusGood} +} + +func (h *SCSIHandler) unmap(cdb [16]byte, dataOut []byte) SCSIResult { + if len(dataOut) < 8 { + return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) + } + + // UNMAP parameter list header (8 bytes) + // descLen := binary.BigEndian.Uint16(dataOut[0:2]) // data length (unused) + blockDescLen := binary.BigEndian.Uint16(dataOut[2:4]) + + if int(blockDescLen)+8 > len(dataOut) { + return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) + } + + // Each UNMAP block descriptor is 16 bytes + descData := dataOut[8 : 8+blockDescLen] + for len(descData) >= 16 { + lba := binary.BigEndian.Uint64(descData[0:8]) + numBlocks := binary.BigEndian.Uint32(descData[8:12]) + // descData[12:16] reserved + + if numBlocks > 0 { + blockSize := h.dev.BlockSize() + if err := h.dev.Trim(lba, numBlocks*blockSize); err != nil { + return SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: SenseMediumError, + SenseASC: 0x0C, + SenseASCQ: 0x00, + } + } + } + descData = descData[16:] + } + + return SCSIResult{Status: SCSIStatusGood} +} + +// BuildSenseData constructs a fixed-format sense data buffer (18 bytes). 
+func BuildSenseData(key, asc, ascq uint8) []byte {
+	data := make([]byte, 18)
+	data[0] = 0x70 // Response code: current errors, fixed format
+	data[2] = key & 0x0f // Sense key (low 4 bits only)
+	data[7] = 10 // Additional sense length
+	data[12] = asc // ASC
+	data[13] = ascq // ASCQ
+	return data
+}
+
+// illegalRequest builds a CHECK CONDITION result carrying ILLEGAL REQUEST
+// sense with the given ASC/ASCQ pair.
+func illegalRequest(asc, ascq uint8) SCSIResult {
+	return SCSIResult{
+		Status: SCSIStatusCheckCond,
+		SenseKey: SenseIllegalRequest,
+		SenseASC: asc,
+		SenseASCQ: ascq,
+	}
+}
+
+// padRight returns s padded with trailing spaces to exactly n bytes,
+// truncating when s is longer (used for fixed-width INQUIRY fields).
+func padRight(s string, n int) string {
+	if len(s) >= n {
+		return s[:n]
+	}
+	b := make([]byte, n)
+	copy(b, s)
+	for i := len(s); i < n; i++ {
+		b[i] = ' '
+	}
+	return string(b)
+}
diff --git a/weed/storage/blockvol/iscsi/scsi_test.go b/weed/storage/blockvol/iscsi/scsi_test.go
new file mode 100644
index 000000000..daa47f355
--- /dev/null
+++ b/weed/storage/blockvol/iscsi/scsi_test.go
@@ -0,0 +1,692 @@
+package iscsi
+
+import (
+	"encoding/binary"
+	"errors"
+	"testing"
+)
+
+// mockBlockDevice implements BlockDevice for testing.
+type mockBlockDevice struct { + blockSize uint32 + volumeSize uint64 + healthy bool + blocks map[uint64][]byte // LBA -> data + syncErr error + readErr error + writeErr error + trimErr error +} + +func newMockDevice(volumeSize uint64) *mockBlockDevice { + return &mockBlockDevice{ + blockSize: 4096, + volumeSize: volumeSize, + healthy: true, + blocks: make(map[uint64][]byte), + } +} + +func (m *mockBlockDevice) ReadAt(lba uint64, length uint32) ([]byte, error) { + if m.readErr != nil { + return nil, m.readErr + } + blockCount := length / m.blockSize + result := make([]byte, length) + for i := uint32(0); i < blockCount; i++ { + if data, ok := m.blocks[lba+uint64(i)]; ok { + copy(result[i*m.blockSize:], data) + } + // Unwritten blocks return zeros (already zeroed) + } + return result, nil +} + +func (m *mockBlockDevice) WriteAt(lba uint64, data []byte) error { + if m.writeErr != nil { + return m.writeErr + } + blockCount := uint32(len(data)) / m.blockSize + for i := uint32(0); i < blockCount; i++ { + block := make([]byte, m.blockSize) + copy(block, data[i*m.blockSize:]) + m.blocks[lba+uint64(i)] = block + } + return nil +} + +func (m *mockBlockDevice) Trim(lba uint64, length uint32) error { + if m.trimErr != nil { + return m.trimErr + } + blockCount := length / m.blockSize + for i := uint32(0); i < blockCount; i++ { + delete(m.blocks, lba+uint64(i)) + } + return nil +} + +func (m *mockBlockDevice) SyncCache() error { return m.syncErr } +func (m *mockBlockDevice) BlockSize() uint32 { return m.blockSize } +func (m *mockBlockDevice) VolumeSize() uint64 { return m.volumeSize } +func (m *mockBlockDevice) IsHealthy() bool { return m.healthy } + +func TestSCSI(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"test_unit_ready_good", testTestUnitReadyGood}, + {"test_unit_ready_not_ready", testTestUnitReadyNotReady}, + {"inquiry_standard", testInquiryStandard}, + {"inquiry_vpd_supported_pages", testInquiryVPDSupportedPages}, + 
{"inquiry_vpd_serial", testInquiryVPDSerial}, + {"inquiry_vpd_device_id", testInquiryVPDDeviceID}, + {"inquiry_vpd_unknown_page", testInquiryVPDUnknownPage}, + {"inquiry_alloc_length", testInquiryAllocLength}, + {"read_capacity_10", testReadCapacity10}, + {"read_capacity_10_large", testReadCapacity10Large}, + {"read_capacity_16", testReadCapacity16}, + {"read_capacity_16_lbpme", testReadCapacity16LBPME}, + {"mode_sense_6", testModeSense6}, + {"report_luns", testReportLuns}, + {"unknown_opcode", testUnknownOpcode}, + {"read_10", testRead10}, + {"read_16", testRead16}, + {"write_10", testWrite10}, + {"write_16", testWrite16}, + {"read_write_roundtrip", testReadWriteRoundtrip}, + {"write_oob", testWriteOOB}, + {"read_oob", testReadOOB}, + {"zero_length_transfer", testZeroLengthTransfer}, + {"sync_cache", testSyncCache}, + {"sync_cache_error", testSyncCacheError}, + {"unmap_single", testUnmapSingle}, + {"unmap_multiple_descriptors", testUnmapMultipleDescriptors}, + {"unmap_short_param", testUnmapShortParam}, + {"build_sense_data", testBuildSenseData}, + {"read_error", testReadError}, + {"write_error", testWriteError}, + {"read_capacity_16_invalid_sa", testReadCapacity16InvalidSA}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testTestUnitReadyGood(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiTestUnitReady + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatalf("status: %d", r.Status) + } +} + +func testTestUnitReadyNotReady(t *testing.T) { + dev := newMockDevice(1024 * 1024) + dev.healthy = false + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiTestUnitReady + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatalf("status: %d", r.Status) + } + if r.SenseKey != SenseNotReady { + t.Fatalf("sense key: %d", r.SenseKey) + } +} + +func testInquiryStandard(t *testing.T) { + dev := 
newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + binary.BigEndian.PutUint16(cdb[3:5], 96) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatalf("status: %d", r.Status) + } + if len(r.Data) != 96 { + t.Fatalf("data length: %d", len(r.Data)) + } + // Check peripheral device type + if r.Data[0] != 0x00 { + t.Fatal("not SBC device type") + } + // Check vendor + vendor := string(r.Data[8:16]) + if vendor != "SeaweedF" { + t.Fatalf("vendor: %q", vendor) + } + // Check CmdQue + if r.Data[7]&0x02 == 0 { + t.Fatal("CmdQue not set") + } +} + +func testInquiryVPDSupportedPages(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + cdb[1] = 0x01 // EVPD + cdb[2] = 0x00 // Supported pages + binary.BigEndian.PutUint16(cdb[3:5], 255) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatalf("status: %d", r.Status) + } + if r.Data[1] != 0x00 { + t.Fatal("wrong page code") + } + // Should list pages 0x00, 0x80, 0x83 + if len(r.Data) < 7 || r.Data[4] != 0x00 || r.Data[5] != 0x80 || r.Data[6] != 0x83 { + t.Fatalf("supported pages: %v", r.Data) + } +} + +func testInquiryVPDSerial(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + cdb[1] = 0x01 + cdb[2] = 0x80 + binary.BigEndian.PutUint16(cdb[3:5], 255) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatalf("status: %d", r.Status) + } + if r.Data[1] != 0x80 { + t.Fatal("wrong page code") + } +} + +func testInquiryVPDDeviceID(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + cdb[1] = 0x01 + cdb[2] = 0x83 + binary.BigEndian.PutUint16(cdb[3:5], 255) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatalf("status: %d", r.Status) + } + if r.Data[1] != 0x83 { + t.Fatal("wrong page code") + } +} + 
+func testInquiryVPDUnknownPage(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + cdb[1] = 0x01 + cdb[2] = 0xFF // unknown page + binary.BigEndian.PutUint16(cdb[3:5], 255) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatal("expected CHECK_CONDITION") + } + if r.SenseKey != SenseIllegalRequest { + t.Fatal("expected ILLEGAL_REQUEST") + } +} + +func testInquiryAllocLength(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiInquiry + binary.BigEndian.PutUint16(cdb[3:5], 10) // small alloc + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("should succeed") + } + if len(r.Data) != 10 { + t.Fatalf("truncation: expected 10, got %d", len(r.Data)) + } +} + +func testReadCapacity10(t *testing.T) { + dev := newMockDevice(100 * 4096) // 100 blocks + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiReadCapacity10 + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("status not good") + } + if len(r.Data) != 8 { + t.Fatalf("data length: %d", len(r.Data)) + } + lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) + blockSize := binary.BigEndian.Uint32(r.Data[4:8]) + if lastLBA != 99 { + t.Fatalf("last LBA: %d, expected 99", lastLBA) + } + if blockSize != 4096 { + t.Fatalf("block size: %d", blockSize) + } +} + +func testReadCapacity10Large(t *testing.T) { + // Volume with >2^32 blocks should return 0xFFFFFFFF + dev := newMockDevice(uint64(0x100000001) * 4096) // 2^32+1 blocks + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiReadCapacity10 + r := h.HandleCommand(cdb, nil) + lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) + if lastLBA != 0xFFFFFFFF { + t.Fatalf("should return 0xFFFFFFFF for >2TB, got %d", lastLBA) + } +} + +func testReadCapacity16(t *testing.T) { + dev := newMockDevice(3 * 1024 * 1024 * 1024 * 1024) // 3 TB + h := NewSCSIHandler(dev) + 
var cdb [16]byte + cdb[0] = ScsiReadCapacity16 + cdb[1] = ScsiSAReadCapacity16 + binary.BigEndian.PutUint32(cdb[10:14], 32) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("status not good") + } + lastLBA := binary.BigEndian.Uint64(r.Data[0:8]) + expectedBlocks := uint64(3*1024*1024*1024*1024) / 4096 + if lastLBA != expectedBlocks-1 { + t.Fatalf("last LBA: %d, expected %d", lastLBA, expectedBlocks-1) + } +} + +func testReadCapacity16LBPME(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiReadCapacity16 + cdb[1] = ScsiSAReadCapacity16 + binary.BigEndian.PutUint32(cdb[10:14], 32) + r := h.HandleCommand(cdb, nil) + // LBPME bit should be set (byte 14, bit 7) + if r.Data[14]&0x80 == 0 { + t.Fatal("LBPME bit not set") + } +} + +func testModeSense6(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiModeSense6 + cdb[4] = 255 + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("status not good") + } + if len(r.Data) != 4 { + t.Fatalf("mode sense data: %d bytes", len(r.Data)) + } + // No write protect + if r.Data[2]&0x80 != 0 { + t.Fatal("write protect set") + } +} + +func testReportLuns(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiReportLuns + binary.BigEndian.PutUint32(cdb[6:10], 256) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("status not good") + } + lunListLen := binary.BigEndian.Uint32(r.Data[0:4]) + if lunListLen != 8 { + t.Fatalf("LUN list length: %d (expected 8 for 1 LUN)", lunListLen) + } +} + +func testUnknownOpcode(t *testing.T) { + dev := newMockDevice(1024 * 1024) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = 0xFF + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatal("expected CHECK_CONDITION") + } + if r.SenseKey != SenseIllegalRequest { + 
t.Fatal("expected ILLEGAL_REQUEST") + } +} + +func testRead10(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + // Write some data first + data := make([]byte, 4096) + for i := range data { + data[i] = 0xAB + } + dev.blocks[5] = data + + var cdb [16]byte + cdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb[2:6], 5) // LBA=5 + binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("read failed") + } + if len(r.Data) != 4096 { + t.Fatalf("data length: %d", len(r.Data)) + } + if r.Data[0] != 0xAB { + t.Fatal("data mismatch") + } +} + +func testRead16(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + data := make([]byte, 4096) + data[0] = 0xCD + dev.blocks[10] = data + + var cdb [16]byte + cdb[0] = ScsiRead16 + binary.BigEndian.PutUint64(cdb[2:10], 10) + binary.BigEndian.PutUint32(cdb[10:14], 1) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("read16 failed") + } + if r.Data[0] != 0xCD { + t.Fatal("data mismatch") + } +} + +func testWrite10(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + dataOut := make([]byte, 4096) + dataOut[0] = 0xEF + + var cdb [16]byte + cdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(cdb[2:6], 7) + binary.BigEndian.PutUint16(cdb[7:9], 1) + r := h.HandleCommand(cdb, dataOut) + if r.Status != SCSIStatusGood { + t.Fatal("write failed") + } + if dev.blocks[7][0] != 0xEF { + t.Fatal("data not written") + } +} + +func testWrite16(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + dataOut := make([]byte, 8192) + dataOut[0] = 0x11 + dataOut[4096] = 0x22 + + var cdb [16]byte + cdb[0] = ScsiWrite16 + binary.BigEndian.PutUint64(cdb[2:10], 50) + binary.BigEndian.PutUint32(cdb[10:14], 2) + r := h.HandleCommand(cdb, dataOut) + if r.Status != SCSIStatusGood { + t.Fatal("write16 failed") + } + if dev.blocks[50][0] != 0x11 { + 
t.Fatal("block 50 wrong") + } + if dev.blocks[51][0] != 0x22 { + t.Fatal("block 51 wrong") + } +} + +func testReadWriteRoundtrip(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + // Write + dataOut := make([]byte, 4096) + for i := range dataOut { + dataOut[i] = byte(i % 256) + } + var wcdb [16]byte + wcdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(wcdb[2:6], 0) + binary.BigEndian.PutUint16(wcdb[7:9], 1) + h.HandleCommand(wcdb, dataOut) + + // Read back + var rcdb [16]byte + rcdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(rcdb[2:6], 0) + binary.BigEndian.PutUint16(rcdb[7:9], 1) + r := h.HandleCommand(rcdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("read failed") + } + for i := 0; i < 4096; i++ { + if r.Data[i] != byte(i%256) { + t.Fatalf("byte %d: got %d, want %d", i, r.Data[i], i%256) + } + } +} + +func testWriteOOB(t *testing.T) { + dev := newMockDevice(10 * 4096) // 10 blocks + h := NewSCSIHandler(dev) + + var cdb [16]byte + cdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(cdb[2:6], 9) + binary.BigEndian.PutUint16(cdb[7:9], 2) // LBA 9 + 2 blocks > 10 + r := h.HandleCommand(cdb, make([]byte, 8192)) + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail for OOB") + } +} + +func testReadOOB(t *testing.T) { + dev := newMockDevice(10 * 4096) + h := NewSCSIHandler(dev) + + var cdb [16]byte + cdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb[2:6], 10) // LBA 10 == total blocks + binary.BigEndian.PutUint16(cdb[7:9], 1) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail for OOB") + } +} + +func testZeroLengthTransfer(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + var cdb [16]byte + cdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb[2:6], 0) + binary.BigEndian.PutUint16(cdb[7:9], 0) // 0 blocks + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("zero-length read should succeed") + } +} + +func testSyncCache(t 
*testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiSyncCache10 + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusGood { + t.Fatal("sync cache failed") + } +} + +func testSyncCacheError(t *testing.T) { + dev := newMockDevice(100 * 4096) + dev.syncErr = errors.New("disk error") + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiSyncCache10 + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail") + } +} + +func testUnmapSingle(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + // Write data at LBA 5 + dev.blocks[5] = make([]byte, 4096) + dev.blocks[5][0] = 0xFF + + // UNMAP parameter list + unmapData := make([]byte, 24) // 8 header + 16 descriptor + binary.BigEndian.PutUint16(unmapData[0:2], 22) // data length + binary.BigEndian.PutUint16(unmapData[2:4], 16) // block desc length + binary.BigEndian.PutUint64(unmapData[8:16], 5) // LBA + binary.BigEndian.PutUint32(unmapData[16:20], 1) // num blocks + + var cdb [16]byte + cdb[0] = ScsiUnmap + r := h.HandleCommand(cdb, unmapData) + if r.Status != SCSIStatusGood { + t.Fatal("unmap failed") + } + if _, ok := dev.blocks[5]; ok { + t.Fatal("block 5 should be trimmed") + } +} + +func testUnmapMultipleDescriptors(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + + dev.blocks[3] = make([]byte, 4096) + dev.blocks[7] = make([]byte, 4096) + + // 2 descriptors + unmapData := make([]byte, 40) // 8 header + 2*16 descriptors + binary.BigEndian.PutUint16(unmapData[0:2], 38) + binary.BigEndian.PutUint16(unmapData[2:4], 32) + // Descriptor 1: LBA=3, count=1 + binary.BigEndian.PutUint64(unmapData[8:16], 3) + binary.BigEndian.PutUint32(unmapData[16:20], 1) + // Descriptor 2: LBA=7, count=1 + binary.BigEndian.PutUint64(unmapData[24:32], 7) + binary.BigEndian.PutUint32(unmapData[32:36], 1) + + var cdb [16]byte + cdb[0] = ScsiUnmap + r := h.HandleCommand(cdb, 
unmapData) + if r.Status != SCSIStatusGood { + t.Fatal("unmap failed") + } + if _, ok := dev.blocks[3]; ok { + t.Fatal("block 3 should be trimmed") + } + if _, ok := dev.blocks[7]; ok { + t.Fatal("block 7 should be trimmed") + } +} + +func testUnmapShortParam(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiUnmap + r := h.HandleCommand(cdb, []byte{1, 2, 3}) // too short + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail for short unmap params") + } +} + +func testBuildSenseData(t *testing.T) { + data := BuildSenseData(SenseIllegalRequest, ASCInvalidOpcode, ASCQLuk) + if len(data) != 18 { + t.Fatalf("length: %d", len(data)) + } + if data[0] != 0x70 { + t.Fatal("response code wrong") + } + if data[2] != SenseIllegalRequest { + t.Fatal("sense key wrong") + } + if data[12] != ASCInvalidOpcode { + t.Fatal("ASC wrong") + } +} + +func testReadError(t *testing.T) { + dev := newMockDevice(100 * 4096) + dev.readErr = errors.New("io error") + h := NewSCSIHandler(dev) + + var cdb [16]byte + cdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb[2:6], 0) + binary.BigEndian.PutUint16(cdb[7:9], 1) + r := h.HandleCommand(cdb, nil) + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail") + } + if r.SenseKey != SenseMediumError { + t.Fatal("should be MEDIUM_ERROR") + } +} + +func testWriteError(t *testing.T) { + dev := newMockDevice(100 * 4096) + dev.writeErr = errors.New("io error") + h := NewSCSIHandler(dev) + + var cdb [16]byte + cdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(cdb[2:6], 0) + binary.BigEndian.PutUint16(cdb[7:9], 1) + r := h.HandleCommand(cdb, make([]byte, 4096)) + if r.Status != SCSIStatusCheckCond { + t.Fatal("should fail") + } +} + +func testReadCapacity16InvalidSA(t *testing.T) { + dev := newMockDevice(100 * 4096) + h := NewSCSIHandler(dev) + var cdb [16]byte + cdb[0] = ScsiReadCapacity16 + cdb[1] = 0x05 // wrong service action + r := h.HandleCommand(cdb, nil) + if r.Status != 
SCSIStatusCheckCond { + t.Fatal("should fail for wrong SA") + } +} diff --git a/weed/storage/blockvol/iscsi/session.go b/weed/storage/blockvol/iscsi/session.go new file mode 100644 index 000000000..898a869aa --- /dev/null +++ b/weed/storage/blockvol/iscsi/session.go @@ -0,0 +1,421 @@ +package iscsi + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "sync" + "sync/atomic" +) + +var ( + ErrSessionClosed = errors.New("iscsi: session closed") + ErrCmdSNOutOfWindow = errors.New("iscsi: CmdSN out of window") +) + +// SessionState tracks the lifecycle of an iSCSI session. +type SessionState int + +const ( + SessionLogin SessionState = iota // login phase + SessionLoggedIn // full feature phase + SessionLogout // logout requested + SessionClosed // terminated +) + +// Session manages a single iSCSI session (one initiator connection). +type Session struct { + mu sync.Mutex + + state SessionState + conn net.Conn + scsi *SCSIHandler + config TargetConfig + resolver TargetResolver + devices DeviceLookup + + // Sequence numbers + expCmdSN atomic.Uint32 // expected CmdSN from initiator + maxCmdSN atomic.Uint32 // max CmdSN we allow + statSN uint32 // target status sequence number + + // Login state + negotiator *LoginNegotiator + loginDone bool + + // Data sequencing + dataInWriter *DataInWriter + + // Shutdown + closed atomic.Bool + closeErr error + + // Logging + logger *log.Logger +} + +// NewSession creates a new iSCSI session on the given connection. 
+func NewSession(conn net.Conn, config TargetConfig, resolver TargetResolver, devices DeviceLookup, logger *log.Logger) *Session {
+	if logger == nil {
+		logger = log.Default() // nil logger falls back to the process-wide default
+	}
+	s := &Session{
+		state:      SessionLogin, // every session starts in the login phase
+		conn:       conn,
+		config:     config,
+		resolver:   resolver,
+		devices:    devices,
+		negotiator: NewLoginNegotiator(config),
+		logger:     logger,
+	}
+	s.expCmdSN.Store(1)  // expect the initiator's first command at CmdSN 1
+	s.maxCmdSN.Store(32) // window of 32 commands
+	return s
+}
+
+// HandleConnection processes PDUs until the connection is closed or an error occurs.
+func (s *Session) HandleConnection() error {
+	defer s.close() // always tear down session state and the conn on exit
+
+	for !s.closed.Load() {
+		pdu, err := ReadPDU(s.conn)
+		if err != nil {
+			if s.closed.Load() {
+				return nil // read failed because we closed deliberately; not an error
+			}
+			if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
+				return nil // peer hung up or conn closed under us: normal termination
+			}
+			return fmt.Errorf("read PDU: %w", err)
+		}
+
+		if err := s.dispatch(pdu); err != nil {
+			if s.closed.Load() {
+				return nil // dispatch error raced with an intentional close; ignore
+			}
+			return fmt.Errorf("dispatch %s: %w", OpcodeName(pdu.Opcode()), err)
+		}
+	}
+	return nil
+}
+
+// Close terminates the session. 
+func (s *Session) Close() error { + s.closed.Store(true) + return s.conn.Close() +} + +func (s *Session) close() { + s.closed.Store(true) + s.mu.Lock() + s.state = SessionClosed + s.mu.Unlock() + s.conn.Close() +} + +func (s *Session) dispatch(pdu *PDU) error { + op := pdu.Opcode() + + switch op { + case OpLoginReq: + return s.handleLogin(pdu) + case OpTextReq: + return s.handleText(pdu) + case OpSCSICmd: + return s.handleSCSICmd(pdu) + case OpSCSIDataOut: + // Handled inline during write command processing + return nil + case OpNOPOut: + return s.handleNOPOut(pdu) + case OpLogoutReq: + return s.handleLogout(pdu) + case OpSCSITaskMgmt: + return s.handleTaskMgmt(pdu) + default: + s.logger.Printf("unhandled opcode: %s", OpcodeName(op)) + return s.sendReject(pdu, 0x04) // command not supported + } +} + +func (s *Session) handleLogin(pdu *PDU) error { + s.mu.Lock() + defer s.mu.Unlock() + + resp := s.negotiator.HandleLoginPDU(pdu, s.resolver) + + // Set sequence numbers + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + + if err := WritePDU(s.conn, resp); err != nil { + return err + } + + if s.negotiator.Done() { + s.loginDone = true + s.state = SessionLoggedIn + result := s.negotiator.Result() + s.dataInWriter = NewDataInWriter(uint32(result.MaxRecvDataSegLen)) + + // Bind SCSI handler to the device for the target the initiator logged into + if s.devices != nil && result.TargetName != "" { + dev := s.devices.LookupDevice(result.TargetName) + s.scsi = NewSCSIHandler(dev) + } + + s.logger.Printf("login complete: initiator=%s target=%s session=%s", + result.InitiatorName, result.TargetName, result.SessionType) + } + + return nil +} + +func (s *Session) handleText(pdu *PDU) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Gather discovery targets from resolver if it supports listing + var targets []DiscoveryTarget + if lister, ok := s.resolver.(TargetLister); ok { + targets = lister.ListTargets() + } + 
+ resp := HandleTextRequest(pdu, targets) + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + + return WritePDU(s.conn, resp) +} + +func (s *Session) handleSCSICmd(pdu *PDU) error { + if !s.loginDone { + return s.sendReject(pdu, 0x0b) // protocol error + } + + cdb := pdu.CDB() + itt := pdu.InitiatorTaskTag() + flags := pdu.OpSpecific1() + + // Advance CmdSN + if !pdu.Immediate() { + s.advanceCmdSN() + } + + isWrite := flags&FlagW != 0 + isRead := flags&FlagR != 0 + expectedLen := pdu.ExpectedDataTransferLength() + + // Handle write commands — collect data + var dataOut []byte + if isWrite && expectedLen > 0 { + collector := NewDataOutCollector(expectedLen) + + // Immediate data + if len(pdu.DataSegment) > 0 { + if err := collector.AddImmediateData(pdu.DataSegment); err != nil { + return s.sendCheckCondition(itt, SenseIllegalRequest, ASCInvalidFieldInCDB, ASCQLuk) + } + } + + // If more data needed, send R2T and collect Data-Out PDUs + if !collector.Done() { + if err := s.collectDataOut(collector, itt); err != nil { + return err + } + } + + dataOut = collector.Data() + } + + // Execute SCSI command + result := s.scsi.HandleCommand(cdb, dataOut) + + // Send response + s.mu.Lock() + expCmdSN := s.expCmdSN.Load() + maxCmdSN := s.maxCmdSN.Load() + s.mu.Unlock() + + if isRead && result.Status == SCSIStatusGood && len(result.Data) > 0 { + // Send Data-In PDUs + s.mu.Lock() + _, err := s.dataInWriter.WriteDataIn(s.conn, result.Data, itt, expCmdSN, maxCmdSN, &s.statSN) + s.mu.Unlock() + return err + } + + // Send SCSI Response + s.mu.Lock() + err := SendSCSIResponse(s.conn, result, itt, &s.statSN, expCmdSN, maxCmdSN) + s.mu.Unlock() + return err +} + +func (s *Session) collectDataOut(collector *DataOutCollector, itt uint32) error { + var r2tSN uint32 + ttt := itt // use ITT as TTT for simplicity + + for !collector.Done() { + // Send R2T + s.mu.Lock() + r2t := BuildR2T(itt, ttt, r2tSN, 
s.totalReceived(collector), collector.Remaining(), + s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load()) + s.mu.Unlock() + + if err := WritePDU(s.conn, r2t); err != nil { + return err + } + r2tSN++ + + // Read Data-Out PDUs until F-bit + for { + doPDU, err := ReadPDU(s.conn) + if err != nil { + return err + } + if doPDU.Opcode() != OpSCSIDataOut { + return fmt.Errorf("expected Data-Out, got %s", OpcodeName(doPDU.Opcode())) + } + if err := collector.AddDataOut(doPDU); err != nil { + return err + } + if doPDU.OpSpecific1()&FlagF != 0 { + break + } + } + } + return nil +} + +func (s *Session) totalReceived(c *DataOutCollector) uint32 { + return c.expectedLen - c.Remaining() +} + +func (s *Session) handleNOPOut(pdu *PDU) error { + resp := &PDU{} + resp.SetOpcode(OpNOPIn) + resp.SetOpSpecific1(FlagF) + resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag()) + resp.SetTargetTransferTag(0xFFFFFFFF) + + s.mu.Lock() + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + s.mu.Unlock() + + // Echo back data if present + if len(pdu.DataSegment) > 0 { + resp.DataSegment = pdu.DataSegment + } + + return WritePDU(s.conn, resp) +} + +func (s *Session) handleLogout(pdu *PDU) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.state = SessionLogout + + resp := &PDU{} + resp.SetOpcode(OpLogoutResp) + resp.SetOpSpecific1(FlagF) + resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag()) + resp.BHS[2] = 0x00 // response: connection/session closed successfully + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + + if err := WritePDU(s.conn, resp); err != nil { + return err + } + + // Signal HandleConnection to exit + s.closed.Store(true) + s.conn.Close() + return nil +} + +func (s *Session) handleTaskMgmt(pdu *PDU) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Simplified: always respond with "function complete" + resp := &PDU{} + resp.SetOpcode(OpSCSITaskResp) + 
resp.SetOpSpecific1(FlagF) + resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag()) + resp.BHS[2] = 0x00 // function complete + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + + return WritePDU(s.conn, resp) +} + +func (s *Session) advanceCmdSN() { + s.expCmdSN.Add(1) + s.maxCmdSN.Add(1) +} + +func (s *Session) sendReject(origPDU *PDU, reason uint8) error { + resp := &PDU{} + resp.SetOpcode(OpReject) + resp.SetOpSpecific1(FlagF) + resp.BHS[2] = reason + resp.SetInitiatorTaskTag(0xFFFFFFFF) + + s.mu.Lock() + resp.SetStatSN(s.statSN) + s.statSN++ + resp.SetExpCmdSN(s.expCmdSN.Load()) + resp.SetMaxCmdSN(s.maxCmdSN.Load()) + s.mu.Unlock() + + // Include the rejected BHS in the data segment + resp.DataSegment = origPDU.BHS[:] + + return WritePDU(s.conn, resp) +} + +func (s *Session) sendCheckCondition(itt uint32, senseKey, asc, ascq uint8) error { + result := SCSIResult{ + Status: SCSIStatusCheckCond, + SenseKey: senseKey, + SenseASC: asc, + SenseASCQ: ascq, + } + s.mu.Lock() + defer s.mu.Unlock() + return SendSCSIResponse(s.conn, result, itt, &s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load()) +} + +// TargetLister is an optional interface that TargetResolver can implement +// to support SendTargets discovery. +type TargetLister interface { + ListTargets() []DiscoveryTarget +} + +// DeviceLookup resolves a target IQN to a BlockDevice. +// Used after login to bind the session to the correct volume. +type DeviceLookup interface { + LookupDevice(iqn string) BlockDevice +} + +// State returns the current session state. 
+func (s *Session) State() SessionState { + s.mu.Lock() + defer s.mu.Unlock() + return s.state +} diff --git a/weed/storage/blockvol/iscsi/session_test.go b/weed/storage/blockvol/iscsi/session_test.go new file mode 100644 index 000000000..3b30427fb --- /dev/null +++ b/weed/storage/blockvol/iscsi/session_test.go @@ -0,0 +1,364 @@ +package iscsi + +import ( + "bytes" + "encoding/binary" + "io" + "log" + "net" + "testing" + "time" +) + +// testResolver implements TargetResolver, TargetLister, and DeviceLookup. +type testResolver struct { + targets []DiscoveryTarget + dev BlockDevice +} + +func (r *testResolver) HasTarget(name string) bool { + for _, t := range r.targets { + if t.Name == name { + return true + } + } + return false +} + +func (r *testResolver) ListTargets() []DiscoveryTarget { + return r.targets +} + +func (r *testResolver) LookupDevice(iqn string) BlockDevice { + if r.dev != nil { + return r.dev + } + return &nullDevice{} +} + +func newTestResolver() *testResolver { + return newTestResolverWithDevice(nil) +} + +func newTestResolverWithDevice(dev BlockDevice) *testResolver { + return &testResolver{ + targets: []DiscoveryTarget{ + {Name: testTargetName, Address: "127.0.0.1:3260,1"}, + }, + dev: dev, + } +} + +func TestSession(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"login_and_read", testLoginAndRead}, + {"login_and_write", testLoginAndWrite}, + {"nop_ping", testNOPPing}, + {"logout", testLogout}, + {"discovery_session", testDiscoverySession}, + {"task_mgmt", testTaskMgmt}, + {"reject_scsi_before_login", testRejectSCSIBeforeLogin}, + {"connection_close", testConnectionClose}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +type sessionTestEnv struct { + clientConn net.Conn + session *Session + done chan error +} + +func setupSession(t *testing.T) *sessionTestEnv { + t.Helper() + client, server := net.Pipe() + + dev := newMockDevice(1024 * 4096) // 1024 blocks + 
config := DefaultTargetConfig() + config.TargetName = testTargetName + resolver := newTestResolverWithDevice(dev) + logger := log.New(io.Discard, "", 0) + + sess := NewSession(server, config, resolver, resolver, logger) + done := make(chan error, 1) + go func() { + done <- sess.HandleConnection() + }() + + t.Cleanup(func() { + sess.Close() + client.Close() + }) + + return &sessionTestEnv{clientConn: client, session: sess, done: done} +} + +func doLogin(t *testing.T, conn net.Conn) { + t.Helper() + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + params.Set("SessionType", "Normal") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + req.SetCmdSN(1) + + if err := WritePDU(conn, req); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpLoginResp { + t.Fatalf("expected LoginResp, got %s", OpcodeName(resp.Opcode())) + } + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) + } +} + +func testLoginAndRead(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + // Send SCSI READ_10 for 1 block at LBA 0 + cmd := &PDU{} + cmd.SetOpcode(OpSCSICmd) + cmd.SetOpSpecific1(FlagF | FlagR) // Final + Read + cmd.SetInitiatorTaskTag(1) + cmd.SetExpectedDataTransferLength(4096) + cmd.SetCmdSN(2) + cmd.SetExpStatSN(2) + + var cdb [16]byte + cdb[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb[2:6], 0) // LBA=0 + binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block + cmd.SetCDB(cdb) + + if err := WritePDU(env.clientConn, cmd); err != nil { + t.Fatal(err) + } + + // Expect Data-In with S-bit (single PDU, 4096 bytes of zeros) + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpSCSIDataIn { + t.Fatalf("expected Data-In, got %s", OpcodeName(resp.Opcode())) + } + if 
resp.OpSpecific1()&FlagS == 0 { + t.Fatal("S-bit not set") + } + if len(resp.DataSegment) != 4096 { + t.Fatalf("data length: %d", len(resp.DataSegment)) + } +} + +func testLoginAndWrite(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + // SCSI WRITE_10: 1 block at LBA 5 with immediate data + cmd := &PDU{} + cmd.SetOpcode(OpSCSICmd) + cmd.SetOpSpecific1(FlagF | FlagW) // Final + Write + cmd.SetInitiatorTaskTag(1) + cmd.SetExpectedDataTransferLength(4096) + cmd.SetCmdSN(2) + + var cdb [16]byte + cdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(cdb[2:6], 5) + binary.BigEndian.PutUint16(cdb[7:9], 1) + cmd.SetCDB(cdb) + + // Include immediate data + cmd.DataSegment = bytes.Repeat([]byte{0xAA}, 4096) + + if err := WritePDU(env.clientConn, cmd); err != nil { + t.Fatal(err) + } + + // Expect SCSI Response (good) + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpSCSIResp { + t.Fatalf("expected SCSI Response, got %s", OpcodeName(resp.Opcode())) + } + if resp.SCSIStatus() != SCSIStatusGood { + t.Fatalf("status: %d", resp.SCSIStatus()) + } +} + +func testNOPPing(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + // Send NOP-Out + nop := &PDU{} + nop.SetOpcode(OpNOPOut) + nop.SetOpSpecific1(FlagF) + nop.SetInitiatorTaskTag(0x9999) + nop.SetImmediate(true) + nop.DataSegment = []byte("ping") + + if err := WritePDU(env.clientConn, nop); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpNOPIn { + t.Fatalf("expected NOP-In, got %s", OpcodeName(resp.Opcode())) + } + if resp.InitiatorTaskTag() != 0x9999 { + t.Fatal("ITT mismatch") + } + if string(resp.DataSegment) != "ping" { + t.Fatalf("echo data: %q", resp.DataSegment) + } +} + +func testLogout(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + // Send Logout + logout := &PDU{} + logout.SetOpcode(OpLogoutReq) + 
logout.SetOpSpecific1(FlagF) + logout.SetInitiatorTaskTag(0xAAAA) + logout.SetCmdSN(2) + + if err := WritePDU(env.clientConn, logout); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpLogoutResp { + t.Fatalf("expected Logout Response, got %s", OpcodeName(resp.Opcode())) + } + + // After logout, the connection should be closed by the target. + // Verify by trying to read — should get EOF. + _, err = ReadPDU(env.clientConn) + if err == nil { + t.Fatal("expected EOF after logout") + } +} + +func testDiscoverySession(t *testing.T) { + env := setupSession(t) + + // Login as discovery session + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("SessionType", "Discovery") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + req.SetCmdSN(1) + WritePDU(env.clientConn, req) + ReadPDU(env.clientConn) // login resp + + // Send SendTargets=All + textParams := NewParams() + textParams.Set("SendTargets", "All") + textReq := makeTextReq(textParams) + textReq.SetCmdSN(2) + + if err := WritePDU(env.clientConn, textReq); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpTextResp { + t.Fatalf("expected Text Response, got %s", OpcodeName(resp.Opcode())) + } + body := string(resp.DataSegment) + if len(body) == 0 { + t.Fatal("empty discovery response") + } +} + +func testTaskMgmt(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + tm := &PDU{} + tm.SetOpcode(OpSCSITaskMgmt) + tm.SetOpSpecific1(FlagF | 0x01) // ABORT TASK + tm.SetInitiatorTaskTag(0xBBBB) + tm.SetImmediate(true) + + if err := WritePDU(env.clientConn, tm); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpSCSITaskResp { + t.Fatalf("expected Task Mgmt Response, got %s", OpcodeName(resp.Opcode())) + 
} +} + +func testRejectSCSIBeforeLogin(t *testing.T) { + env := setupSession(t) + + // Send SCSI command without login + cmd := &PDU{} + cmd.SetOpcode(OpSCSICmd) + cmd.SetOpSpecific1(FlagF | FlagR) + cmd.SetInitiatorTaskTag(1) + + if err := WritePDU(env.clientConn, cmd); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(env.clientConn) + if err != nil { + t.Fatal(err) + } + if resp.Opcode() != OpReject { + t.Fatalf("expected Reject, got %s", OpcodeName(resp.Opcode())) + } +} + +func testConnectionClose(t *testing.T) { + env := setupSession(t) + doLogin(t, env.clientConn) + + // Close client side + env.clientConn.Close() + + select { + case err := <-env.done: + if err != nil { + t.Fatalf("unexpected error on clean close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("session did not detect close") + } +} diff --git a/weed/storage/blockvol/iscsi/target.go b/weed/storage/blockvol/iscsi/target.go new file mode 100644 index 000000000..8746a22ce --- /dev/null +++ b/weed/storage/blockvol/iscsi/target.go @@ -0,0 +1,216 @@ +package iscsi + +import ( + "errors" + "fmt" + "log" + "net" + "sync" + "sync/atomic" +) + +var ( + ErrTargetClosed = errors.New("iscsi: target server closed") + ErrVolumeNotFound = errors.New("iscsi: volume not found") +) + +// TargetServer manages the iSCSI target: TCP listener, volume registry, +// and active sessions. +type TargetServer struct { + mu sync.RWMutex + listener net.Listener + config TargetConfig + volumes map[string]BlockDevice // target IQN -> device + addr string + + // Active session tracking for graceful shutdown + activeMu sync.Mutex + active map[uint64]*Session + nextID atomic.Uint64 + + sessions sync.WaitGroup + closed chan struct{} + logger *log.Logger +} + +// NewTargetServer creates a target server bound to the given address. 
+func NewTargetServer(addr string, config TargetConfig, logger *log.Logger) *TargetServer { + if logger == nil { + logger = log.Default() + } + return &TargetServer{ + config: config, + volumes: make(map[string]BlockDevice), + active: make(map[uint64]*Session), + addr: addr, + closed: make(chan struct{}), + logger: logger, + } +} + +// AddVolume registers a block device under the given target IQN. +func (ts *TargetServer) AddVolume(iqn string, dev BlockDevice) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.volumes[iqn] = dev + ts.logger.Printf("volume added: %s", iqn) +} + +// RemoveVolume unregisters a target IQN. +func (ts *TargetServer) RemoveVolume(iqn string) { + ts.mu.Lock() + defer ts.mu.Unlock() + delete(ts.volumes, iqn) + ts.logger.Printf("volume removed: %s", iqn) +} + +// HasTarget implements TargetResolver. +func (ts *TargetServer) HasTarget(name string) bool { + ts.mu.RLock() + defer ts.mu.RUnlock() + _, ok := ts.volumes[name] + return ok +} + +// ListTargets implements TargetLister. +func (ts *TargetServer) ListTargets() []DiscoveryTarget { + ts.mu.RLock() + defer ts.mu.RUnlock() + targets := make([]DiscoveryTarget, 0, len(ts.volumes)) + for iqn := range ts.volumes { + targets = append(targets, DiscoveryTarget{ + Name: iqn, + Address: ts.ListenAddr(), + }) + } + return targets +} + +// ListenAndServe starts listening for iSCSI connections. +// Blocks until Close() is called or an error occurs. +func (ts *TargetServer) ListenAndServe() error { + ln, err := net.Listen("tcp", ts.addr) + if err != nil { + return fmt.Errorf("iscsi target: listen: %w", err) + } + ts.mu.Lock() + ts.listener = ln + ts.mu.Unlock() + + ts.logger.Printf("iSCSI target listening on %s", ln.Addr()) + return ts.acceptLoop(ln) +} + +// Serve accepts connections on an existing listener. 
+func (ts *TargetServer) Serve(ln net.Listener) error { + ts.mu.Lock() + ts.listener = ln + ts.mu.Unlock() + return ts.acceptLoop(ln) +} + +func (ts *TargetServer) acceptLoop(ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-ts.closed: + return nil + default: + return fmt.Errorf("iscsi target: accept: %w", err) + } + } + + ts.sessions.Add(1) + go ts.handleConn(conn) + } +} + +func (ts *TargetServer) handleConn(conn net.Conn) { + defer ts.sessions.Done() + defer conn.Close() + + ts.logger.Printf("new connection from %s", conn.RemoteAddr()) + + sess := NewSession(conn, ts.config, ts, ts, ts.logger) + + id := ts.nextID.Add(1) + ts.activeMu.Lock() + ts.active[id] = sess + ts.activeMu.Unlock() + + defer func() { + ts.activeMu.Lock() + delete(ts.active, id) + ts.activeMu.Unlock() + }() + + if err := sess.HandleConnection(); err != nil { + ts.logger.Printf("session error (%s): %v", conn.RemoteAddr(), err) + } + + ts.logger.Printf("connection closed: %s", conn.RemoteAddr()) +} + +// LookupDevice implements DeviceLookup. Returns the BlockDevice for the given IQN, +// or a nullDevice if not found. +func (ts *TargetServer) LookupDevice(iqn string) BlockDevice { + ts.mu.RLock() + defer ts.mu.RUnlock() + + if dev, ok := ts.volumes[iqn]; ok { + return dev + } + return &nullDevice{} +} + +// Close gracefully shuts down the target server. +func (ts *TargetServer) Close() error { + select { + case <-ts.closed: + return nil + default: + } + close(ts.closed) + + ts.mu.RLock() + ln := ts.listener + ts.mu.RUnlock() + + if ln != nil { + ln.Close() + } + + // Close all active sessions to unblock ReadPDU + ts.activeMu.Lock() + for _, sess := range ts.active { + sess.Close() + } + ts.activeMu.Unlock() + + ts.sessions.Wait() + return nil +} + +// ListenAddr returns the actual listen address (useful when port=0). 
+func (ts *TargetServer) ListenAddr() string { + ts.mu.RLock() + defer ts.mu.RUnlock() + if ts.listener != nil { + return ts.listener.Addr().String() + } + return ts.addr +} + +// nullDevice is a stub BlockDevice used when no volumes are registered. +type nullDevice struct{} + +func (d *nullDevice) ReadAt(lba uint64, length uint32) ([]byte, error) { + return make([]byte, length), nil +} +func (d *nullDevice) WriteAt(lba uint64, data []byte) error { return nil } +func (d *nullDevice) Trim(lba uint64, length uint32) error { return nil } +func (d *nullDevice) SyncCache() error { return nil } +func (d *nullDevice) BlockSize() uint32 { return 4096 } +func (d *nullDevice) VolumeSize() uint64 { return 0 } +func (d *nullDevice) IsHealthy() bool { return false } diff --git a/weed/storage/blockvol/iscsi/target_test.go b/weed/storage/blockvol/iscsi/target_test.go new file mode 100644 index 000000000..58f80b4c2 --- /dev/null +++ b/weed/storage/blockvol/iscsi/target_test.go @@ -0,0 +1,259 @@ +package iscsi + +import ( + "encoding/binary" + "io" + "log" + "net" + "strings" + "testing" + "time" +) + +func TestTarget(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {"listen_and_connect", testListenAndConnect}, + {"discovery_via_target", testDiscoveryViaTarget}, + {"login_read_write", testTargetLoginReadWrite}, + {"graceful_shutdown", testGracefulShutdown}, + {"add_remove_volume", testAddRemoveVolume}, + {"multiple_connections", testMultipleConnections}, + {"connect_no_volumes", testConnectNoVolumes}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func setupTarget(t *testing.T) (*TargetServer, string) { + t.Helper() + config := DefaultTargetConfig() + config.TargetName = testTargetName + logger := log.New(io.Discard, "", 0) + ts := NewTargetServer("127.0.0.1:0", config, logger) + + dev := newMockDevice(256 * 4096) // 256 blocks = 1MB + ts.AddVolume(testTargetName, dev) + + ln, err := 
net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + go ts.Serve(ln) + + addr := ln.Addr().String() + t.Cleanup(func() { + ts.Close() + }) + + return ts, addr +} + +func dialTarget(t *testing.T, addr string) net.Conn { + t.Helper() + conn, err := net.DialTimeout("tcp", addr, 2*time.Second) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + return conn +} + +func loginToTarget(t *testing.T, conn net.Conn) { + t.Helper() + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("TargetName", testTargetName) + params.Set("SessionType", "Normal") + + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + req.SetCmdSN(1) + + if err := WritePDU(conn, req); err != nil { + t.Fatal(err) + } + + resp, err := ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) + } +} + +func testListenAndConnect(t *testing.T) { + _, addr := setupTarget(t) + + conn := dialTarget(t, addr) + loginToTarget(t, conn) +} + +func testDiscoveryViaTarget(t *testing.T) { + _, addr := setupTarget(t) + + conn := dialTarget(t, addr) + + // Login as discovery + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("SessionType", "Discovery") + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + req.SetCmdSN(1) + WritePDU(conn, req) + ReadPDU(conn) + + // SendTargets=All + textParams := NewParams() + textParams.Set("SendTargets", "All") + textReq := makeTextReq(textParams) + textReq.SetCmdSN(2) + WritePDU(conn, textReq) + + resp, err := ReadPDU(conn) + if err != nil { + t.Fatal(err) + } + body := string(resp.DataSegment) + if !strings.Contains(body, testTargetName) { + t.Fatalf("discovery response missing target: %q", body) + } +} + +func testTargetLoginReadWrite(t *testing.T) { + _, addr := setupTarget(t) + conn := dialTarget(t, addr) 
+ loginToTarget(t, conn) + + // Write 1 block at LBA 0 + cmd := &PDU{} + cmd.SetOpcode(OpSCSICmd) + cmd.SetOpSpecific1(FlagF | FlagW) + cmd.SetInitiatorTaskTag(1) + cmd.SetExpectedDataTransferLength(4096) + cmd.SetCmdSN(2) + var cdb [16]byte + cdb[0] = ScsiWrite10 + binary.BigEndian.PutUint32(cdb[2:6], 0) + binary.BigEndian.PutUint16(cdb[7:9], 1) + cmd.SetCDB(cdb) + data := make([]byte, 4096) + for i := range data { + data[i] = 0xAB + } + cmd.DataSegment = data + + WritePDU(conn, cmd) + resp, _ := ReadPDU(conn) + if resp.SCSIStatus() != SCSIStatusGood { + t.Fatalf("write failed: %d", resp.SCSIStatus()) + } + + // Read it back + cmd2 := &PDU{} + cmd2.SetOpcode(OpSCSICmd) + cmd2.SetOpSpecific1(FlagF | FlagR) + cmd2.SetInitiatorTaskTag(2) + cmd2.SetExpectedDataTransferLength(4096) + cmd2.SetCmdSN(3) + var cdb2 [16]byte + cdb2[0] = ScsiRead10 + binary.BigEndian.PutUint32(cdb2[2:6], 0) + binary.BigEndian.PutUint16(cdb2[7:9], 1) + cmd2.SetCDB(cdb2) + + WritePDU(conn, cmd2) + resp2, _ := ReadPDU(conn) + if resp2.Opcode() != OpSCSIDataIn { + t.Fatalf("expected Data-In, got %s", OpcodeName(resp2.Opcode())) + } + if resp2.DataSegment[0] != 0xAB { + t.Fatal("data mismatch") + } +} + +func testGracefulShutdown(t *testing.T) { + ts, addr := setupTarget(t) + + conn := dialTarget(t, addr) + loginToTarget(t, conn) + + // Close the target — should shut down cleanly + ts.Close() + + // Connection should be dropped + _, err := ReadPDU(conn) + if err == nil { + t.Fatal("expected error after shutdown") + } +} + +func testAddRemoveVolume(t *testing.T) { + config := DefaultTargetConfig() + logger := log.New(io.Discard, "", 0) + ts := NewTargetServer("127.0.0.1:0", config, logger) + + ts.AddVolume("iqn.2024.com.test:v1", newMockDevice(1024*4096)) + ts.AddVolume("iqn.2024.com.test:v2", newMockDevice(1024*4096)) + + if !ts.HasTarget("iqn.2024.com.test:v1") { + t.Fatal("v1 should exist") + } + + ts.RemoveVolume("iqn.2024.com.test:v1") + if ts.HasTarget("iqn.2024.com.test:v1") { + t.Fatal("v1 
should be removed") + } + if !ts.HasTarget("iqn.2024.com.test:v2") { + t.Fatal("v2 should still exist") + } + + targets := ts.ListTargets() + if len(targets) != 1 { + t.Fatalf("expected 1 target, got %d", len(targets)) + } +} + +func testMultipleConnections(t *testing.T) { + _, addr := setupTarget(t) + + // Connect multiple clients + for i := 0; i < 5; i++ { + conn := dialTarget(t, addr) + loginToTarget(t, conn) + } +} + +func testConnectNoVolumes(t *testing.T) { + config := DefaultTargetConfig() + logger := log.New(io.Discard, "", 0) + ts := NewTargetServer("127.0.0.1:0", config, logger) + // No volumes added + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + go ts.Serve(ln) + t.Cleanup(func() { ts.Close() }) + + conn := dialTarget(t, ln.Addr().String()) + + // Login should work (discovery is fine without volumes) + params := NewParams() + params.Set("InitiatorName", testInitiatorName) + params.Set("SessionType", "Discovery") + req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) + req.SetCmdSN(1) + WritePDU(conn, req) + resp, _ := ReadPDU(conn) + if resp.LoginStatusClass() != LoginStatusSuccess { + t.Fatal("discovery login should succeed without volumes") + } +} diff --git a/weed/storage/blockvol/lba.go b/weed/storage/blockvol/lba.go new file mode 100644 index 000000000..431f6eaad --- /dev/null +++ b/weed/storage/blockvol/lba.go @@ -0,0 +1,38 @@ +package blockvol + +import ( + "errors" + "fmt" +) + +var ( + ErrLBAOutOfBounds = errors.New("blockvol: LBA out of bounds") + ErrWritePastEnd = errors.New("blockvol: write extends past volume end") + ErrAlignment = errors.New("blockvol: data length not aligned to block size") +) + +// ValidateLBA checks that lba is within the volume's logical address space. 
+func ValidateLBA(lba uint64, volumeSize uint64, blockSize uint32) error { + maxLBA := volumeSize / uint64(blockSize) + if lba >= maxLBA { + return fmt.Errorf("%w: lba=%d, max=%d", ErrLBAOutOfBounds, lba, maxLBA-1) + } + return nil +} + +// ValidateWrite checks that a write at lba with dataLen bytes fits within +// the volume and that dataLen is aligned to blockSize. +func ValidateWrite(lba uint64, dataLen uint32, volumeSize uint64, blockSize uint32) error { + if err := ValidateLBA(lba, volumeSize, blockSize); err != nil { + return err + } + if dataLen%blockSize != 0 { + return fmt.Errorf("%w: dataLen=%d, blockSize=%d", ErrAlignment, dataLen, blockSize) + } + blocksNeeded := uint64(dataLen / blockSize) + maxLBA := volumeSize / uint64(blockSize) + if lba+blocksNeeded > maxLBA { + return fmt.Errorf("%w: lba=%d, blocks=%d, max=%d", ErrWritePastEnd, lba, blocksNeeded, maxLBA) + } + return nil +} diff --git a/weed/storage/blockvol/lba_test.go b/weed/storage/blockvol/lba_test.go new file mode 100644 index 000000000..74952ee6f --- /dev/null +++ b/weed/storage/blockvol/lba_test.go @@ -0,0 +1,95 @@ +package blockvol + +import ( + "errors" + "testing" +) + +func TestLBAValidation(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "lba_within_bounds", run: testLBAWithinBounds}, + {name: "lba_last_block", run: testLBALastBlock}, + {name: "lba_out_of_bounds", run: testLBAOutOfBounds}, + {name: "lba_write_spans_end", run: testLBAWriteSpansEnd}, + {name: "lba_alignment", run: testLBAAlignment}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +const ( + testVolSize = 100 * 1024 * 1024 * 1024 // 100GB + testBlockSize = 4096 +) + +func testLBAWithinBounds(t *testing.T) { + err := ValidateLBA(0, testVolSize, testBlockSize) + if err != nil { + t.Errorf("LBA=0 should be valid: %v", err) + } + + err = ValidateLBA(1000, testVolSize, testBlockSize) + if err != nil { + t.Errorf("LBA=1000 should be 
valid: %v", err) + } +} + +func testLBALastBlock(t *testing.T) { + maxLBA := uint64(testVolSize/testBlockSize) - 1 + err := ValidateLBA(maxLBA, testVolSize, testBlockSize) + if err != nil { + t.Errorf("last LBA=%d should be valid: %v", maxLBA, err) + } +} + +func testLBAOutOfBounds(t *testing.T) { + maxLBA := uint64(testVolSize / testBlockSize) + err := ValidateLBA(maxLBA, testVolSize, testBlockSize) + if !errors.Is(err, ErrLBAOutOfBounds) { + t.Errorf("LBA=%d (one past end): expected ErrLBAOutOfBounds, got %v", maxLBA, err) + } + + err = ValidateLBA(maxLBA+1000, testVolSize, testBlockSize) + if !errors.Is(err, ErrLBAOutOfBounds) { + t.Errorf("LBA far past end: expected ErrLBAOutOfBounds, got %v", err) + } +} + +func testLBAWriteSpansEnd(t *testing.T) { + maxLBA := uint64(testVolSize / testBlockSize) + // Write 2 blocks starting at last LBA -- spans past end. + lastLBA := maxLBA - 1 + err := ValidateWrite(lastLBA, 2*testBlockSize, testVolSize, testBlockSize) + if !errors.Is(err, ErrWritePastEnd) { + t.Errorf("write spanning end: expected ErrWritePastEnd, got %v", err) + } + + // Write 1 block at last LBA -- should succeed. 
+ err = ValidateWrite(lastLBA, testBlockSize, testVolSize, testBlockSize) + if err != nil { + t.Errorf("write at last LBA should succeed: %v", err) + } +} + +func testLBAAlignment(t *testing.T) { + err := ValidateWrite(0, 4096, testVolSize, testBlockSize) + if err != nil { + t.Errorf("aligned write should succeed: %v", err) + } + + err = ValidateWrite(0, 4000, testVolSize, testBlockSize) + if !errors.Is(err, ErrAlignment) { + t.Errorf("unaligned write: expected ErrAlignment, got %v", err) + } + + err = ValidateWrite(0, 1, testVolSize, testBlockSize) + if !errors.Is(err, ErrAlignment) { + t.Errorf("1-byte write: expected ErrAlignment, got %v", err) + } +} diff --git a/weed/storage/blockvol/recovery.go b/weed/storage/blockvol/recovery.go new file mode 100644 index 000000000..22bbff78d --- /dev/null +++ b/weed/storage/blockvol/recovery.go @@ -0,0 +1,157 @@ +package blockvol + +import ( + "fmt" + "os" +) + +// RecoveryResult contains the outcome of WAL recovery. +type RecoveryResult struct { + EntriesReplayed int // number of entries replayed into dirty map + HighestLSN uint64 // highest LSN seen during recovery + TornEntries int // entries discarded due to CRC failure +} + +// RecoverWAL scans the WAL region from tail to head, replaying valid entries +// into the dirty map. Entries with LSN <= checkpointLSN are skipped (already +// in extent). Scanning stops at the first CRC failure (torn write). +// +// The WAL is a circular buffer. If head >= tail, scan [tail, head). +// If head < tail (wrapped), scan [tail, walSize) then [0, head). +func RecoverWAL(fd *os.File, sb *Superblock, dirtyMap *DirtyMap) (RecoveryResult, error) { + result := RecoveryResult{} + + logicalHead := sb.WALHead + logicalTail := sb.WALTail + walOffset := sb.WALOffset + walSize := sb.WALSize + checkpointLSN := sb.WALCheckpointLSN + + if logicalHead == logicalTail { + // WAL is empty (or fully flushed). + return result, nil + } + + // Convert logical positions to physical. 
+ physHead := logicalHead % walSize + physTail := logicalTail % walSize + + // Build the list of byte ranges to scan. + type scanRange struct { + start, end uint64 // physical positions within WAL + } + + var ranges []scanRange + if physHead > physTail { + // No wrap: scan [tail, head). + ranges = append(ranges, scanRange{physTail, physHead}) + } else if physHead == physTail { + // Head and tail at same physical position but different logical positions + // means the WAL is completely full. Scan the entire region. + ranges = append(ranges, scanRange{physTail, walSize}) + if physHead > 0 { + ranges = append(ranges, scanRange{0, physHead}) + } + } else { + // Wrapped: scan [tail, walSize) then [0, head). + ranges = append(ranges, scanRange{physTail, walSize}) + if physHead > 0 { + ranges = append(ranges, scanRange{0, physHead}) + } + } + + for _, r := range ranges { + pos := r.start + for pos < r.end { + remaining := r.end - pos + + // Need at least a header to proceed. + if remaining < uint64(walEntryHeaderSize) { + break + } + + // Read header. + headerBuf := make([]byte, walEntryHeaderSize) + absOff := int64(walOffset + pos) + if _, err := fd.ReadAt(headerBuf, absOff); err != nil { + return result, fmt.Errorf("recovery: read header at WAL+%d: %w", pos, err) + } + + // Parse entry type and length field. + entryType := headerBuf[16] + lengthField := parseLength(headerBuf) + + // For padding entries, skip forward. + if entryType == EntryTypePadding { + entrySize := uint64(walEntryHeaderSize) + uint64(lengthField) + pos += entrySize + continue + } + + // Calculate on-disk entry size. WRITE and PADDING carry data payload; + // TRIM and BARRIER do not (Length is metadata, not data size). + var payloadLen uint64 + if entryType == EntryTypeWrite { + payloadLen = uint64(lengthField) + } + entrySize := uint64(walEntryHeaderSize) + payloadLen + if entrySize > remaining { + // Torn write: entry extends past available data. 
+ result.TornEntries++ + break + } + + // Read full entry. + fullBuf := make([]byte, entrySize) + if _, err := fd.ReadAt(fullBuf, absOff); err != nil { + return result, fmt.Errorf("recovery: read entry at WAL+%d: %w", pos, err) + } + + // Decode and validate CRC. + entry, err := DecodeWALEntry(fullBuf) + if err != nil { + // CRC failure or corrupt entry — stop here (torn write). + result.TornEntries++ + break + } + + // Skip entries already flushed to extent. + if entry.LSN <= checkpointLSN { + pos += entrySize + continue + } + + // Replay entry. + switch entry.Type { + case EntryTypeWrite: + blocks := entry.Length / sb.BlockSize + for i := uint32(0); i < blocks; i++ { + dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) + } + result.EntriesReplayed++ + + case EntryTypeTrim: + // TRIM carries Length (bytes) covering multiple blocks. + blocks := entry.Length / sb.BlockSize + if blocks == 0 { + blocks = 1 // legacy single-block trim + } + for i := uint32(0); i < blocks; i++ { + dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) + } + result.EntriesReplayed++ + + case EntryTypeBarrier: + // Barriers don't modify data, just skip. 
+ } + + if entry.LSN > result.HighestLSN { + result.HighestLSN = entry.LSN + } + + pos += entrySize + } + } + + return result, nil +} diff --git a/weed/storage/blockvol/recovery_test.go b/weed/storage/blockvol/recovery_test.go new file mode 100644 index 000000000..bd23390d7 --- /dev/null +++ b/weed/storage/blockvol/recovery_test.go @@ -0,0 +1,416 @@ +package blockvol + +import ( + "bytes" + "path/filepath" + "testing" + "time" +) + +func TestRecovery(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "recover_empty_wal", run: testRecoverEmptyWAL}, + {name: "recover_one_entry", run: testRecoverOneEntry}, + {name: "recover_many_entries", run: testRecoverManyEntries}, + {name: "recover_torn_write", run: testRecoverTornWrite}, + {name: "recover_after_checkpoint", run: testRecoverAfterCheckpoint}, + {name: "recover_idempotent", run: testRecoverIdempotent}, + {name: "recover_wal_full", run: testRecoverWALFull}, + {name: "recover_barrier_only", run: testRecoverBarrierOnly}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +// simulateCrash closes the volume without syncing and returns the path. +func simulateCrash(v *BlockVol) string { + path := v.Path() + v.groupCommit.Stop() + v.fd.Close() + return path +} + +func testRecoverEmptyWAL(t *testing.T) { + v := createTestVol(t) + // Sync to make superblock durable. + v.fd.Sync() + path := simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Empty volume: read should return zeros. 
+ got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("expected zeros from empty volume") + } +} + +func testRecoverOneEntry(t *testing.T) { + v := createTestVol(t) + data := makeBlock('A') + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + // Sync WAL to make the entry durable. + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Update superblock with WAL state. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path := simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA after recovery: %v", err) + } + if !bytes.Equal(got, data) { + t.Error("data not recovered") + } +} + +func testRecoverManyEntries(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "many.blockvol") + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 4 * 1024 * 1024, // 4MB + BlockSize: 4096, + WALSize: 2 * 1024 * 1024, // 2MB WAL -- enough for ~500 entries + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + const numWrites = 400 + for i := uint64(0); i < numWrites; i++ { + if err := v.WriteLBA(i, makeBlock(byte(i%26+'A'))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := uint64(0); i < numWrites; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d) after recovery: 
%v", i, err) + } + expected := makeBlock(byte(i%26 + 'A')) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: data mismatch after recovery", i) + } + } +} + +func testRecoverTornWrite(t *testing.T) { + v := createTestVol(t) + + // Write 2 entries. + if err := v.WriteLBA(0, makeBlock('A')); err != nil { + t.Fatalf("WriteLBA(0): %v", err) + } + if err := v.WriteLBA(1, makeBlock('B')); err != nil { + t.Fatalf("WriteLBA(1): %v", err) + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Save superblock with correct WAL state. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + // Corrupt the last 2 bytes of the second entry to simulate torn write. + entrySize := uint64(walEntryHeaderSize + 4096) + secondEntryEnd := v.super.WALOffset + entrySize*2 + corruptOff := int64(secondEntryEnd - 2) + v.fd.WriteAt([]byte{0xFF, 0xFF}, corruptOff) + v.fd.Sync() + + path := simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // First entry should be recovered. + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA(0) after torn recovery: %v", err) + } + if !bytes.Equal(got, makeBlock('A')) { + t.Error("block 0 should be recovered") + } + + // Second entry was torn -- should NOT be in dirty map. + // Read returns zeros (from extent). + got, err = v2.ReadLBA(1, 4096) + if err != nil { + t.Fatalf("ReadLBA(1) after torn recovery: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("block 1 (torn) should return zeros") + } +} + +func testRecoverAfterCheckpoint(t *testing.T) { + v := createTestVol(t) + + // Write blocks 0-9. 
+ for i := uint64(0); i < 10; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Flush first 5 blocks (flusher moves them to extent, advances checkpoint). + f := NewFlusher(FlusherConfig{ + FD: v.fd, + Super: &v.super, + WAL: v.wal, + DirtyMap: v.dirtyMap, + Interval: 1 * time.Hour, + }) + + // Flush all (flusher takes all dirty entries). + // To simulate partial flush, we'll manually set checkpoint to midpoint. + if err := f.FlushOnce(); err != nil { + t.Fatalf("FlushOnce: %v", err) + } + + // Now write 5 more blocks (these will need replay after crash). + for i := uint64(10); i < 15; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + // Update superblock. + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path := simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // Blocks 0-9 should be in extent (flushed). Blocks 10-14 replayed from WAL. 
+ for i := uint64(0); i < 15; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: data mismatch after checkpoint recovery", i) + } + } +} + +func testRecoverIdempotent(t *testing.T) { + v := createTestVol(t) + data := makeBlock('X') + if err := v.WriteLBA(0, data); err != nil { + t.Fatalf("WriteLBA: %v", err) + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path := simulateCrash(v) + + // First recovery. + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol 1: %v", err) + } + + // Verify data. + got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA 1: %v", err) + } + if !bytes.Equal(got, data) { + t.Error("first recovery: data mismatch") + } + + // Close and recover again (should be idempotent). + path2 := simulateCrash(v2) + + v3, err := OpenBlockVol(path2) + if err != nil { + t.Fatalf("OpenBlockVol 2: %v", err) + } + defer v3.Close() + + got, err = v3.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA 2: %v", err) + } + if !bytes.Equal(got, data) { + t.Error("second recovery: data mismatch") + } +} + +func testRecoverWALFull(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "full.blockvol") + + entrySize := uint64(walEntryHeaderSize + 4096) + walSize := entrySize * 5 // room for exactly 5 entries + + v, err := CreateBlockVol(path, CreateOptions{ + VolumeSize: 1 * 1024 * 1024, + BlockSize: 4096, + WALSize: walSize, + }) + if err != nil { + t.Fatalf("CreateBlockVol: %v", err) + } + + // Write exactly 5 entries to fill WAL. 
+ for i := uint64(0); i < 5; i++ { + if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { + t.Fatalf("WriteLBA(%d): %v", i, err) + } + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path = simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + for i := uint64(0); i < 5; i++ { + got, err := v2.ReadLBA(i, 4096) + if err != nil { + t.Fatalf("ReadLBA(%d): %v", i, err) + } + expected := makeBlock(byte('A' + i)) + if !bytes.Equal(got, expected) { + t.Errorf("block %d: data mismatch after full WAL recovery", i) + } + } +} + +func testRecoverBarrierOnly(t *testing.T) { + v := createTestVol(t) + + // Write a barrier entry. + lsn := v.nextLSN.Add(1) - 1 + entry := &WALEntry{ + LSN: lsn, + Type: EntryTypeBarrier, + LBA: 0, + } + if _, err := v.wal.Append(entry); err != nil { + t.Fatalf("Append barrier: %v", err) + } + + if err := v.SyncCache(); err != nil { + t.Fatalf("SyncCache: %v", err) + } + + v.super.WALHead = v.wal.LogicalHead() + v.super.WALTail = v.wal.LogicalTail() + v.fd.Seek(0, 0) + v.super.WriteTo(v.fd) + v.fd.Sync() + + path := simulateCrash(v) + + v2, err := OpenBlockVol(path) + if err != nil { + t.Fatalf("OpenBlockVol: %v", err) + } + defer v2.Close() + + // No data changes from barrier. 
+ got, err := v2.ReadLBA(0, 4096) + if err != nil { + t.Fatalf("ReadLBA: %v", err) + } + if !bytes.Equal(got, make([]byte, 4096)) { + t.Error("barrier-only WAL should leave data as zeros") + } +} diff --git a/weed/storage/blockvol/superblock.go b/weed/storage/blockvol/superblock.go new file mode 100644 index 000000000..a6a341d21 --- /dev/null +++ b/weed/storage/blockvol/superblock.go @@ -0,0 +1,249 @@ +package blockvol + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/google/uuid" +) + +const ( + SuperblockSize = 4096 + MagicSWBK = "SWBK" + CurrentVersion = 1 +) + +var ( + ErrNotBlockVol = errors.New("blockvol: not a blockvol file (bad magic)") + ErrUnsupportedVersion = errors.New("blockvol: unsupported version") + ErrInvalidVolumeSize = errors.New("blockvol: volume size must be > 0") +) + +// Superblock is the 4KB header at offset 0 of a blockvol file. +// It identifies the file format, stores volume geometry, and tracks WAL state. +type Superblock struct { + Magic [4]byte + Version uint16 + Flags uint16 + UUID [16]byte + VolumeSize uint64 // logical size in bytes + ExtentSize uint32 // default 64KB + BlockSize uint32 // default 4KB (min I/O unit) + WALOffset uint64 // byte offset where WAL region starts + WALSize uint64 // WAL region size in bytes + WALHead uint64 // logical WAL write position (monotonically increasing) + WALTail uint64 // logical WAL flush position (monotonically increasing) + WALCheckpointLSN uint64 // last LSN flushed to extent region + Replication [4]byte + CreatedAt uint64 // unix timestamp + SnapshotCount uint32 +} + +// superblockOnDisk is the fixed-size on-disk layout (binary.Write/Read target). +// Must be <= SuperblockSize bytes. Remaining space is zero-padded. 
// NOTE(review): despite the note above, serialization below is done manually
// with encoding/binary byte-order helpers, not binary.Write; this struct is a
// field-for-field staging copy of Superblock. Field order matches the manual
// byte layout in WriteTo/ReadSuperblock -- do not reorder.
type superblockOnDisk struct {
	Magic            [4]byte
	Version          uint16
	Flags            uint16
	UUID             [16]byte
	VolumeSize       uint64
	ExtentSize       uint32
	BlockSize        uint32
	WALOffset        uint64
	WALSize          uint64
	WALHead          uint64
	WALTail          uint64
	WALCheckpointLSN uint64
	Replication      [4]byte
	CreatedAt        uint64
	SnapshotCount    uint32
}

// NewSuperblock creates a superblock with defaults and a fresh UUID.
// Zero-valued options fall back to: ExtentSize 64KB, BlockSize 4KB,
// WALSize 64MB, Replication "000". WALOffset is always SuperblockSize
// (the WAL region starts immediately after the 4KB header). CreatedAt
// is left zero here; CreateBlockVol stamps it after construction.
func NewSuperblock(volumeSize uint64, opts CreateOptions) (Superblock, error) {
	if volumeSize == 0 {
		return Superblock{}, ErrInvalidVolumeSize
	}

	extentSize := opts.ExtentSize
	if extentSize == 0 {
		extentSize = 64 * 1024 // 64KB
	}
	blockSize := opts.BlockSize
	if blockSize == 0 {
		blockSize = 4096 // 4KB
	}
	walSize := opts.WALSize
	if walSize == 0 {
		walSize = 64 * 1024 * 1024 // 64MB
	}

	var repl [4]byte
	if opts.Replication != "" {
		// copy silently truncates opts.Replication to 4 bytes if longer.
		copy(repl[:], opts.Replication)
	} else {
		copy(repl[:], "000")
	}

	id := uuid.New()

	sb := Superblock{
		Version:    CurrentVersion,
		VolumeSize: volumeSize,
		ExtentSize: extentSize,
		BlockSize:  blockSize,
		WALOffset:  SuperblockSize,
		WALSize:    walSize,
	}
	copy(sb.Magic[:], MagicSWBK)
	sb.UUID = id
	sb.Replication = repl

	return sb, nil
}

// WriteTo serializes the superblock to w as a 4096-byte block.
func (sb *Superblock) WriteTo(w io.Writer) (int64, error) {
	buf := make([]byte, SuperblockSize)

	// NOTE(review): d is populated and then encoded field-by-field below;
	// binary.Write is never used, so this staging copy is organizational only.
	d := superblockOnDisk{
		Magic:            sb.Magic,
		Version:          sb.Version,
		Flags:            sb.Flags,
		UUID:             sb.UUID,
		VolumeSize:       sb.VolumeSize,
		ExtentSize:       sb.ExtentSize,
		BlockSize:        sb.BlockSize,
		WALOffset:        sb.WALOffset,
		WALSize:          sb.WALSize,
		WALHead:          sb.WALHead,
		WALTail:          sb.WALTail,
		WALCheckpointLSN: sb.WALCheckpointLSN,
		Replication:      sb.Replication,
		CreatedAt:        sb.CreatedAt,
		SnapshotCount:    sb.SnapshotCount,
	}

	// Encode into beginning of buf; rest stays zero (padding).
+ endian := binary.LittleEndian + off := 0 + off += copy(buf[off:], d.Magic[:]) + endian.PutUint16(buf[off:], d.Version) + off += 2 + endian.PutUint16(buf[off:], d.Flags) + off += 2 + off += copy(buf[off:], d.UUID[:]) + endian.PutUint64(buf[off:], d.VolumeSize) + off += 8 + endian.PutUint32(buf[off:], d.ExtentSize) + off += 4 + endian.PutUint32(buf[off:], d.BlockSize) + off += 4 + endian.PutUint64(buf[off:], d.WALOffset) + off += 8 + endian.PutUint64(buf[off:], d.WALSize) + off += 8 + endian.PutUint64(buf[off:], d.WALHead) + off += 8 + endian.PutUint64(buf[off:], d.WALTail) + off += 8 + endian.PutUint64(buf[off:], d.WALCheckpointLSN) + off += 8 + off += copy(buf[off:], d.Replication[:]) + endian.PutUint64(buf[off:], d.CreatedAt) + off += 8 + endian.PutUint32(buf[off:], d.SnapshotCount) + + n, err := w.Write(buf) + return int64(n), err +} + +// ReadSuperblock reads and validates a superblock from r. +func ReadSuperblock(r io.Reader) (Superblock, error) { + buf := make([]byte, SuperblockSize) + if _, err := io.ReadFull(r, buf); err != nil { + return Superblock{}, fmt.Errorf("blockvol: read superblock: %w", err) + } + + endian := binary.LittleEndian + var sb Superblock + off := 0 + copy(sb.Magic[:], buf[off:off+4]) + off += 4 + + if string(sb.Magic[:]) != MagicSWBK { + return Superblock{}, ErrNotBlockVol + } + + sb.Version = endian.Uint16(buf[off:]) + off += 2 + if sb.Version != CurrentVersion { + return Superblock{}, fmt.Errorf("%w: got %d, want %d", ErrUnsupportedVersion, sb.Version, CurrentVersion) + } + + sb.Flags = endian.Uint16(buf[off:]) + off += 2 + copy(sb.UUID[:], buf[off:off+16]) + off += 16 + sb.VolumeSize = endian.Uint64(buf[off:]) + off += 8 + + if sb.VolumeSize == 0 { + return Superblock{}, ErrInvalidVolumeSize + } + + sb.ExtentSize = endian.Uint32(buf[off:]) + off += 4 + sb.BlockSize = endian.Uint32(buf[off:]) + off += 4 + sb.WALOffset = endian.Uint64(buf[off:]) + off += 8 + sb.WALSize = endian.Uint64(buf[off:]) + off += 8 + sb.WALHead = 
endian.Uint64(buf[off:]) + off += 8 + sb.WALTail = endian.Uint64(buf[off:]) + off += 8 + sb.WALCheckpointLSN = endian.Uint64(buf[off:]) + off += 8 + copy(sb.Replication[:], buf[off:off+4]) + off += 4 + sb.CreatedAt = endian.Uint64(buf[off:]) + off += 8 + sb.SnapshotCount = endian.Uint32(buf[off:]) + + return sb, nil +} + +var ErrInvalidSuperblock = errors.New("blockvol: invalid superblock") + +// Validate checks that the superblock fields are internally consistent. +func (sb *Superblock) Validate() error { + if string(sb.Magic[:]) != MagicSWBK { + return ErrNotBlockVol + } + if sb.Version != CurrentVersion { + return fmt.Errorf("%w: got %d", ErrUnsupportedVersion, sb.Version) + } + if sb.VolumeSize == 0 { + return ErrInvalidVolumeSize + } + if sb.BlockSize == 0 { + return fmt.Errorf("%w: BlockSize is 0", ErrInvalidSuperblock) + } + if sb.ExtentSize == 0 { + return fmt.Errorf("%w: ExtentSize is 0", ErrInvalidSuperblock) + } + if sb.WALSize == 0 { + return fmt.Errorf("%w: WALSize is 0", ErrInvalidSuperblock) + } + if sb.WALOffset != SuperblockSize { + return fmt.Errorf("%w: WALOffset=%d, expected %d", ErrInvalidSuperblock, sb.WALOffset, SuperblockSize) + } + if sb.VolumeSize%uint64(sb.BlockSize) != 0 { + return fmt.Errorf("%w: VolumeSize %d not aligned to BlockSize %d", ErrInvalidSuperblock, sb.VolumeSize, sb.BlockSize) + } + return nil +} diff --git a/weed/storage/blockvol/superblock_test.go b/weed/storage/blockvol/superblock_test.go new file mode 100644 index 000000000..8decf51c6 --- /dev/null +++ b/weed/storage/blockvol/superblock_test.go @@ -0,0 +1,146 @@ +package blockvol + +import ( + "bytes" + "encoding/binary" + "errors" + "testing" +) + +func TestSuperblock(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "superblock_roundtrip", run: testSuperblockRoundtrip}, + {name: "superblock_magic_check", run: testSuperblockMagicCheck}, + {name: "superblock_version_check", run: testSuperblockVersionCheck}, + {name: 
"superblock_zero_vol_size", run: testSuperblockZeroVolSize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testSuperblockRoundtrip(t *testing.T) { + sb, err := NewSuperblock(100*1024*1024*1024, CreateOptions{ + ExtentSize: 64 * 1024, + BlockSize: 4096, + WALSize: 64 * 1024 * 1024, + Replication: "010", + }) + if err != nil { + t.Fatalf("NewSuperblock: %v", err) + } + sb.CreatedAt = 1700000000 + sb.SnapshotCount = 3 + sb.WALHead = 8192 + sb.WALCheckpointLSN = 42 + + var buf bytes.Buffer + n, err := sb.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo: %v", err) + } + if n != SuperblockSize { + t.Fatalf("WriteTo wrote %d bytes, want %d", n, SuperblockSize) + } + + got, err := ReadSuperblock(&buf) + if err != nil { + t.Fatalf("ReadSuperblock: %v", err) + } + + if string(got.Magic[:]) != MagicSWBK { + t.Errorf("Magic = %q, want %q", got.Magic, MagicSWBK) + } + if got.Version != CurrentVersion { + t.Errorf("Version = %d, want %d", got.Version, CurrentVersion) + } + if got.UUID != sb.UUID { + t.Errorf("UUID mismatch") + } + if got.VolumeSize != sb.VolumeSize { + t.Errorf("VolumeSize = %d, want %d", got.VolumeSize, sb.VolumeSize) + } + if got.ExtentSize != sb.ExtentSize { + t.Errorf("ExtentSize = %d, want %d", got.ExtentSize, sb.ExtentSize) + } + if got.BlockSize != sb.BlockSize { + t.Errorf("BlockSize = %d, want %d", got.BlockSize, sb.BlockSize) + } + if got.WALOffset != sb.WALOffset { + t.Errorf("WALOffset = %d, want %d", got.WALOffset, sb.WALOffset) + } + if got.WALSize != sb.WALSize { + t.Errorf("WALSize = %d, want %d", got.WALSize, sb.WALSize) + } + if got.WALHead != sb.WALHead { + t.Errorf("WALHead = %d, want %d", got.WALHead, sb.WALHead) + } + if got.WALCheckpointLSN != sb.WALCheckpointLSN { + t.Errorf("WALCheckpointLSN = %d, want %d", got.WALCheckpointLSN, sb.WALCheckpointLSN) + } + if got.Replication != sb.Replication { + t.Errorf("Replication = %q, want %q", got.Replication, sb.Replication) + } + if 
got.CreatedAt != sb.CreatedAt { + t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, sb.CreatedAt) + } + if got.SnapshotCount != sb.SnapshotCount { + t.Errorf("SnapshotCount = %d, want %d", got.SnapshotCount, sb.SnapshotCount) + } +} + +func testSuperblockMagicCheck(t *testing.T) { + // Write a valid superblock, then corrupt the magic bytes. + sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) + var buf bytes.Buffer + sb.WriteTo(&buf) + + data := buf.Bytes() + copy(data[0:4], "XXXX") // corrupt magic + + _, err := ReadSuperblock(bytes.NewReader(data)) + if !errors.Is(err, ErrNotBlockVol) { + t.Errorf("expected ErrNotBlockVol, got %v", err) + } +} + +func testSuperblockVersionCheck(t *testing.T) { + sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) + var buf bytes.Buffer + sb.WriteTo(&buf) + + data := buf.Bytes() + // Version is at offset 4, uint16 little-endian. + binary.LittleEndian.PutUint16(data[4:], 99) + + _, err := ReadSuperblock(bytes.NewReader(data)) + if !errors.Is(err, ErrUnsupportedVersion) { + t.Errorf("expected ErrUnsupportedVersion, got %v", err) + } +} + +func testSuperblockZeroVolSize(t *testing.T) { + _, err := NewSuperblock(0, CreateOptions{}) + if !errors.Is(err, ErrInvalidVolumeSize) { + t.Errorf("NewSuperblock(0): expected ErrInvalidVolumeSize, got %v", err) + } + + // Also test reading a superblock with zero volume size. + sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) + var buf bytes.Buffer + sb.WriteTo(&buf) + + data := buf.Bytes() + // VolumeSize is at offset 4+2+2+16 = 24, uint64 little-endian. 
+ binary.LittleEndian.PutUint64(data[24:], 0) + + _, err = ReadSuperblock(bytes.NewReader(data)) + if !errors.Is(err, ErrInvalidVolumeSize) { + t.Errorf("ReadSuperblock(vol_size=0): expected ErrInvalidVolumeSize, got %v", err) + } +} diff --git a/weed/storage/blockvol/wal_entry.go b/weed/storage/blockvol/wal_entry.go new file mode 100644 index 000000000..d3c8c428d --- /dev/null +++ b/weed/storage/blockvol/wal_entry.go @@ -0,0 +1,153 @@ +package blockvol + +import ( + "encoding/binary" + "errors" + "fmt" + "hash/crc32" +) + +const ( + EntryTypeWrite = 0x01 + EntryTypeTrim = 0x02 + EntryTypeBarrier = 0x03 + EntryTypePadding = 0xFF + + // walEntryHeaderSize is the fixed portion: LSN(8) + Epoch(8) + Type(1) + + // Flags(1) + LBA(8) + Length(4) + CRC32(4) + EntrySize(4) = 38 bytes. + walEntryHeaderSize = 38 +) + +var ( + ErrCRCMismatch = errors.New("blockvol: CRC mismatch") + ErrInvalidEntry = errors.New("blockvol: invalid WAL entry") + ErrEntryTruncated = errors.New("blockvol: WAL entry truncated") +) + +// WALEntry is a variable-size record in the WAL region. +type WALEntry struct { + LSN uint64 + Epoch uint64 // writer's epoch (Phase 1: always 0) + Type uint8 // EntryTypeWrite, EntryTypeTrim, EntryTypeBarrier + Flags uint8 + LBA uint64 // in blocks + Length uint32 // data length in bytes (WRITE: data size, TRIM: trim size, BARRIER: 0) + Data []byte // present only for WRITE + CRC32 uint32 // covers LSN through Data + EntrySize uint32 // total serialized size +} + +// Encode serializes the entry into a byte slice, computing CRC over all fields. 
+func (e *WALEntry) Encode() ([]byte, error) { + switch e.Type { + case EntryTypeWrite: + if len(e.Data) == 0 { + return nil, fmt.Errorf("%w: WRITE entry with no data", ErrInvalidEntry) + } + if uint32(len(e.Data)) != e.Length { + return nil, fmt.Errorf("%w: data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length) + } + case EntryTypeTrim: + if len(e.Data) != 0 { + return nil, fmt.Errorf("%w: TRIM entry must have no data payload", ErrInvalidEntry) + } + // TRIM carries Length (trim size in bytes) but no Data payload. + case EntryTypeBarrier: + if e.Length != 0 || len(e.Data) != 0 { + return nil, fmt.Errorf("%w: BARRIER entry must have no data", ErrInvalidEntry) + } + } + + totalSize := uint32(walEntryHeaderSize + len(e.Data)) + buf := make([]byte, totalSize) + + le := binary.LittleEndian + off := 0 + le.PutUint64(buf[off:], e.LSN) + off += 8 + le.PutUint64(buf[off:], e.Epoch) + off += 8 + buf[off] = e.Type + off++ + buf[off] = e.Flags + off++ + le.PutUint64(buf[off:], e.LBA) + off += 8 + le.PutUint32(buf[off:], e.Length) + off += 4 + + if len(e.Data) > 0 { + copy(buf[off:], e.Data) + off += len(e.Data) + } + + // CRC covers everything from start through Data (offset 0 to off). + checksum := crc32.ChecksumIEEE(buf[:off]) + le.PutUint32(buf[off:], checksum) + off += 4 + le.PutUint32(buf[off:], totalSize) + + e.CRC32 = checksum + e.EntrySize = totalSize + return buf, nil +} + +// DecodeWALEntry deserializes a WAL entry from buf, validating CRC. 
func DecodeWALEntry(buf []byte) (WALEntry, error) {
	// Reject buffers that cannot even hold the fixed header + footer.
	if len(buf) < walEntryHeaderSize {
		return WALEntry{}, fmt.Errorf("%w: need %d bytes, have %d", ErrEntryTruncated, walEntryHeaderSize, len(buf))
	}

	le := binary.LittleEndian
	var e WALEntry
	off := 0
	e.LSN = le.Uint64(buf[off:])
	off += 8
	e.Epoch = le.Uint64(buf[off:])
	off += 8
	e.Type = buf[off]
	off++
	e.Flags = buf[off]
	off++
	e.LBA = le.Uint64(buf[off:])
	off += 8
	e.Length = le.Uint32(buf[off:])
	off += 4

	// For WRITE entries, Length is the data payload size.
	// For TRIM entries, Length is the trim extent in bytes (no data payload).
	// For BARRIER/PADDING, Length is 0 (or padding size).
	// NOTE(review): int(e.Length) can wrap negative on 32-bit platforms for
	// adversarial Length values; the dataEnd bound check below assumes a
	// 64-bit int -- confirm supported targets.
	var dataLen int
	if e.Type == EntryTypeWrite || e.Type == EntryTypePadding {
		dataLen = int(e.Length)
	}

	// buf may be longer than the entry (e.g. a slice of the WAL region);
	// trailing bytes past the footer are ignored.
	dataEnd := off + dataLen
	if dataEnd+8 > len(buf) { // +8 for CRC32 + EntrySize
		return WALEntry{}, fmt.Errorf("%w: need %d bytes for data+footer, have %d", ErrEntryTruncated, dataEnd+8, len(buf))
	}

	// Copy the payload out so the returned entry does not alias buf.
	if dataLen > 0 {
		e.Data = make([]byte, dataLen)
		copy(e.Data, buf[off:dataEnd])
	}
	off = dataEnd

	e.CRC32 = le.Uint32(buf[off:])
	off += 4
	e.EntrySize = le.Uint32(buf[off:])

	// Verify CRC: covers LSN through Data (the footer is excluded, matching
	// Encode). A mismatch indicates corruption or a torn write.
	expected := crc32.ChecksumIEEE(buf[:dataEnd])
	if e.CRC32 != expected {
		return WALEntry{}, fmt.Errorf("%w: stored=%08x computed=%08x", ErrCRCMismatch, e.CRC32, expected)
	}

	// Verify EntrySize matches actual layout.
+ expectedSize := uint32(walEntryHeaderSize) + uint32(dataLen) + if e.EntrySize != expectedSize { + return WALEntry{}, fmt.Errorf("%w: EntrySize=%d, expected=%d", ErrInvalidEntry, e.EntrySize, expectedSize) + } + + return e, nil +} diff --git a/weed/storage/blockvol/wal_entry_test.go b/weed/storage/blockvol/wal_entry_test.go new file mode 100644 index 000000000..6c1f4547b --- /dev/null +++ b/weed/storage/blockvol/wal_entry_test.go @@ -0,0 +1,284 @@ +package blockvol + +import ( + "bytes" + "encoding/binary" + "errors" + "testing" +) + +func TestWALEntry(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + {name: "wal_entry_roundtrip", run: testWALEntryRoundtrip}, + {name: "wal_entry_trim_no_data", run: testWALEntryTrimNoData}, + {name: "wal_entry_barrier", run: testWALEntryBarrier}, + {name: "wal_entry_crc_valid", run: testWALEntryCRCValid}, + {name: "wal_entry_crc_corrupt", run: testWALEntryCRCCorrupt}, + {name: "wal_entry_max_size", run: testWALEntryMaxSize}, + {name: "wal_entry_zero_length", run: testWALEntryZeroLength}, + {name: "wal_entry_trim_rejects_data", run: testWALEntryTrimRejectsData}, + {name: "wal_entry_barrier_rejects_data", run: testWALEntryBarrierRejectsData}, + {name: "wal_entry_bad_entry_size", run: testWALEntryBadEntrySize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func testWALEntryRoundtrip(t *testing.T) { + data := make([]byte, 4096) + for i := range data { + data[i] = byte(i % 251) // deterministic pattern + } + + e := WALEntry{ + LSN: 1, + Epoch: 0, + Type: EntryTypeWrite, + Flags: 0, + LBA: 100, + Length: uint32(len(data)), + Data: data, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + got, err := DecodeWALEntry(buf) + if err != nil { + t.Fatalf("DecodeWALEntry: %v", err) + } + + if got.LSN != e.LSN { + t.Errorf("LSN = %d, want %d", got.LSN, e.LSN) + } + if got.Epoch != e.Epoch { + t.Errorf("Epoch = %d, want %d", 
got.Epoch, e.Epoch) + } + if got.Type != e.Type { + t.Errorf("Type = %d, want %d", got.Type, e.Type) + } + if got.LBA != e.LBA { + t.Errorf("LBA = %d, want %d", got.LBA, e.LBA) + } + if got.Length != e.Length { + t.Errorf("Length = %d, want %d", got.Length, e.Length) + } + if !bytes.Equal(got.Data, e.Data) { + t.Errorf("Data mismatch") + } + if got.EntrySize != uint32(walEntryHeaderSize+len(data)) { + t.Errorf("EntrySize = %d, want %d", got.EntrySize, walEntryHeaderSize+len(data)) + } +} + +func testWALEntryTrimNoData(t *testing.T) { + // TRIM carries Length (trim size in bytes) but no Data payload. + e := WALEntry{ + LSN: 5, + Epoch: 0, + Type: EntryTypeTrim, + LBA: 200, + Length: 4096, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + got, err := DecodeWALEntry(buf) + if err != nil { + t.Fatalf("DecodeWALEntry: %v", err) + } + + if got.Type != EntryTypeTrim { + t.Errorf("Type = %d, want %d", got.Type, EntryTypeTrim) + } + if got.Length != 4096 { + t.Errorf("Length = %d, want 4096", got.Length) + } + if len(got.Data) != 0 { + t.Errorf("Data should be empty, got %d bytes", len(got.Data)) + } + // TRIM with Length but no Data: EntrySize = header only. 
+ if got.EntrySize != uint32(walEntryHeaderSize) { + t.Errorf("EntrySize = %d, want %d (header only)", got.EntrySize, walEntryHeaderSize) + } +} + +func testWALEntryBarrier(t *testing.T) { + e := WALEntry{ + LSN: 10, + Epoch: 0, + Type: EntryTypeBarrier, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + got, err := DecodeWALEntry(buf) + if err != nil { + t.Fatalf("DecodeWALEntry: %v", err) + } + + if got.Type != EntryTypeBarrier { + t.Errorf("Type = %d, want %d", got.Type, EntryTypeBarrier) + } + if len(got.Data) != 0 { + t.Errorf("Barrier should have no data, got %d bytes", len(got.Data)) + } +} + +func testWALEntryCRCValid(t *testing.T) { + data := []byte("hello blockvol WAL") + e := WALEntry{ + LSN: 1, + Type: EntryTypeWrite, + LBA: 0, + Length: uint32(len(data)), + Data: data, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Decode should succeed with valid CRC. + _, err = DecodeWALEntry(buf) + if err != nil { + t.Fatalf("valid entry should decode without error, got: %v", err) + } +} + +func testWALEntryCRCCorrupt(t *testing.T) { + data := []byte("hello blockvol WAL") + e := WALEntry{ + LSN: 1, + Type: EntryTypeWrite, + LBA: 0, + Length: uint32(len(data)), + Data: data, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Flip one bit in the data region. 
+ buf[walEntryHeaderSize-8] ^= 0x01 // flip bit in data area (before CRC/EntrySize footer) + + _, err = DecodeWALEntry(buf) + if !errors.Is(err, ErrCRCMismatch) { + t.Errorf("expected ErrCRCMismatch, got %v", err) + } +} + +func testWALEntryMaxSize(t *testing.T) { + data := make([]byte, 64*1024) // 64KB + for i := range data { + data[i] = byte(i % 256) + } + + e := WALEntry{ + LSN: 99, + Epoch: 0, + Type: EntryTypeWrite, + LBA: 0, + Length: uint32(len(data)), + Data: data, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode 64KB entry: %v", err) + } + + got, err := DecodeWALEntry(buf) + if err != nil { + t.Fatalf("DecodeWALEntry: %v", err) + } + + if !bytes.Equal(got.Data, data) { + t.Errorf("64KB data mismatch after roundtrip") + } +} + +func testWALEntryZeroLength(t *testing.T) { + e := WALEntry{ + LSN: 1, + Type: EntryTypeWrite, + LBA: 0, + Length: 0, + Data: nil, + } + + _, err := e.Encode() + if !errors.Is(err, ErrInvalidEntry) { + t.Errorf("WRITE with zero data: expected ErrInvalidEntry, got %v", err) + } +} + +func testWALEntryTrimRejectsData(t *testing.T) { + // TRIM allows Length (trim extent) but rejects Data payload. 
+ e := WALEntry{ + LSN: 1, + Type: EntryTypeTrim, + LBA: 0, + Length: 4096, + Data: []byte("bad!"), + } + _, err := e.Encode() + if !errors.Is(err, ErrInvalidEntry) { + t.Errorf("TRIM with data payload: expected ErrInvalidEntry, got %v", err) + } +} + +func testWALEntryBarrierRejectsData(t *testing.T) { + e := WALEntry{ + LSN: 1, + Type: EntryTypeBarrier, + Length: 1, + Data: []byte("x"), + } + _, err := e.Encode() + if !errors.Is(err, ErrInvalidEntry) { + t.Errorf("BARRIER with data: expected ErrInvalidEntry, got %v", err) + } +} + +func testWALEntryBadEntrySize(t *testing.T) { + data := []byte("test data for entry size validation") + e := WALEntry{ + LSN: 1, + Type: EntryTypeWrite, + LBA: 0, + Length: uint32(len(data)), + Data: data, + } + + buf, err := e.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Corrupt EntrySize (last 4 bytes of buf). + le := binary.LittleEndian + le.PutUint32(buf[len(buf)-4:], 9999) + + _, err = DecodeWALEntry(buf) + if !errors.Is(err, ErrInvalidEntry) { + t.Errorf("bad EntrySize: expected ErrInvalidEntry, got %v", err) + } +} diff --git a/weed/storage/blockvol/wal_writer.go b/weed/storage/blockvol/wal_writer.go new file mode 100644 index 000000000..a16c741c0 --- /dev/null +++ b/weed/storage/blockvol/wal_writer.go @@ -0,0 +1,192 @@ +package blockvol + +import ( + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "os" + "sync" +) + +var ( + ErrWALFull = errors.New("blockvol: WAL region full") +) + +// WALWriter appends entries to the circular WAL region of a blockvol file. +// +// It uses logical (monotonically increasing) head and tail counters to track +// used space. Physical position = logical % walSize. This eliminates the +// classic circular buffer ambiguity where head==tail could mean empty or full. +// Used space = logicalHead - logicalTail. Free space = walSize - used. 
type WALWriter struct {
	mu          sync.Mutex // guards all fields below and serializes file writes
	fd          *os.File
	walOffset   uint64 // absolute file offset where WAL region starts
	walSize     uint64 // size of the WAL region in bytes
	logicalHead uint64 // monotonically increasing write position
	logicalTail uint64 // monotonically increasing flush position
}

// NewWALWriter creates a WAL writer for the given file.
// head and tail are LOGICAL (monotonically increasing) positions, as persisted
// in the superblock's WALHead/WALTail fields -- they are stored directly into
// logicalHead/logicalTail, not translated from physical offsets.
// For a fresh WAL, both are 0 (where physical == logical).
func NewWALWriter(fd *os.File, walOffset, walSize, head, tail uint64) *WALWriter {
	return &WALWriter{
		fd:          fd,
		walOffset:   walOffset,
		walSize:     walSize,
		logicalHead: head, // on fresh WAL, physical == logical (both start at 0)
		logicalTail: tail,
	}
}

// physicalPos converts a logical position to a physical WAL offset.
func (w *WALWriter) physicalPos(logical uint64) uint64 {
	return logical % w.walSize
}

// used returns the number of bytes occupied in the WAL.
// Caller must hold w.mu.
func (w *WALWriter) used() uint64 {
	return w.logicalHead - w.logicalTail
}

// Append writes a serialized WAL entry to the circular WAL region.
// Returns the physical WAL-relative offset where the entry was written.
// If the entry doesn't fit in the remaining space before the region end,
// a padding entry is written and the real entry starts at physical offset 0.
// The pwrite and all head/tail bookkeeping happen under w.mu; Append does
// NOT fsync -- callers sync separately (see Sync / group commit).
func (w *WALWriter) Append(entry *WALEntry) (walRelOffset uint64, err error) {
	buf, err := entry.Encode()
	if err != nil {
		return 0, fmt.Errorf("WALWriter.Append: encode: %w", err)
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// An entry larger than the whole region can never fit, regardless of tail.
	entryLen := uint64(len(buf))
	if entryLen > w.walSize {
		return 0, fmt.Errorf("%w: entry size %d exceeds WAL size %d", ErrWALFull, entryLen, w.walSize)
	}

	physHead := w.physicalPos(w.logicalHead)
	remaining := w.walSize - physHead

	if remaining < entryLen {
		// Not enough room at end of region -- write padding and wrap.
		// Padding consumes 'remaining' bytes logically.
		if w.used()+remaining+entryLen > w.walSize {
			return 0, ErrWALFull
		}
		if err := w.writePadding(remaining, physHead); err != nil {
			return 0, fmt.Errorf("WALWriter.Append: padding: %w", err)
		}
		w.logicalHead += remaining
		physHead = 0
	}

	// Check if there's enough free space for the entry.
	if w.used()+entryLen > w.walSize {
		return 0, ErrWALFull
	}

	absOffset := int64(w.walOffset + physHead)
	if _, err := w.fd.WriteAt(buf, absOffset); err != nil {
		return 0, fmt.Errorf("WALWriter.Append: pwrite at offset %d: %w", absOffset, err)
	}

	// Advance head only after the pwrite succeeded, so a failed write leaves
	// the WAL state unchanged.
	writeOffset := physHead
	w.logicalHead += entryLen
	return writeOffset, nil
}

// writePadding writes a padding entry at the given physical position.
// If size is too small to hold even an entry header, the bytes are zeroed
// instead; recovery must treat such a short tail-of-region run as padding.
// Caller must hold w.mu.
func (w *WALWriter) writePadding(size uint64, physPos uint64) error {
	if size < walEntryHeaderSize {
		// Too small for a proper entry header -- zero it out.
		buf := make([]byte, size)
		absOffset := int64(w.walOffset + physPos)
		_, err := w.fd.WriteAt(buf, absOffset)
		return err
	}

	// Build a PADDING entry in the same wire format as WALEntry.Encode:
	// fixed header, zeroed payload, CRC32 + EntrySize footer.
	buf := make([]byte, size)
	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], 0) // LSN=0
	off += 8
	le.PutUint64(buf[off:], 0) // Epoch=0
	off += 8
	buf[off] = EntryTypePadding
	off++
	buf[off] = 0 // Flags
	off++
	le.PutUint64(buf[off:], 0) // LBA=0
	off += 8
	paddingDataLen := uint32(size) - uint32(walEntryHeaderSize)
	le.PutUint32(buf[off:], paddingDataLen)
	off += 4
	dataEnd := off + int(paddingDataLen)

	crc := crc32.ChecksumIEEE(buf[:dataEnd])
	le.PutUint32(buf[dataEnd:], crc)
	le.PutUint32(buf[dataEnd+4:], uint32(size))

	absOffset := int64(w.walOffset + physPos)
	_, err := w.fd.WriteAt(buf, absOffset)
	return err
}

// AdvanceTail moves the tail forward, freeing WAL space.
// Called by the flusher after entries have been written to the extent region.
// newTail is a physical position; it is converted to a logical advance.
//
// NOTE(review): newTail equal to the current physical tail is interpreted as
// a zero-length advance. When the WAL is completely full, head and tail share
// the same physical position, so this API cannot express "free the entire
// region" in one call -- the tail would stay put and the WAL would remain
// full. Confirm the flusher never needs a full-region advance, or switch to
// passing the logical tail.
func (w *WALWriter) AdvanceTail(newTail uint64) {
	w.mu.Lock()
	physTail := w.physicalPos(w.logicalTail)
	var advance uint64
	if newTail >= physTail {
		advance = newTail - physTail
	} else {
		// Tail wrapped around.
		advance = w.walSize - physTail + newTail
	}
	w.logicalTail += advance
	w.mu.Unlock()
}

// Head returns the current physical head position (relative to WAL start).
// Physical positions alone cannot distinguish a full WAL from an empty one
// (Head == Tail in both cases); use LogicalHead/LogicalTail for that.
func (w *WALWriter) Head() uint64 {
	w.mu.Lock()
	h := w.physicalPos(w.logicalHead)
	w.mu.Unlock()
	return h
}

// Tail returns the current physical tail position (relative to WAL start).
// See the caveat on Head about full-vs-empty ambiguity.
func (w *WALWriter) Tail() uint64 {
	w.mu.Lock()
	t := w.physicalPos(w.logicalTail)
	w.mu.Unlock()
	return t
}

// LogicalHead returns the logical (monotonically increasing) head position.
// This is the value persisted to Superblock.WALHead.
func (w *WALWriter) LogicalHead() uint64 {
	w.mu.Lock()
	h := w.logicalHead
	w.mu.Unlock()
	return h
}

// LogicalTail returns the logical (monotonically increasing) tail position.
// This is the value persisted to Superblock.WALTail.
func (w *WALWriter) LogicalTail() uint64 {
	w.mu.Lock()
	t := w.logicalTail
	w.mu.Unlock()
	return t
}

// Sync fsyncs the underlying file descriptor.
+func (w *WALWriter) Sync() error {
+	return w.fd.Sync()
+}
diff --git a/weed/storage/blockvol/wal_writer_test.go b/weed/storage/blockvol/wal_writer_test.go
new file mode 100644
index 000000000..efa97524b
--- /dev/null
+++ b/weed/storage/blockvol/wal_writer_test.go
@@ -0,0 +1,244 @@
+package blockvol
+
+import (
+	"errors"
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+// TestWALWriter runs the WAL writer subtests via a table so each case shows
+// up as a named t.Run entry.
+func TestWALWriter(t *testing.T) {
+	tests := []struct {
+		name string
+		run  func(t *testing.T)
+	}{
+		{name: "wal_writer_append_read_back", run: testWALWriterAppendReadBack},
+		{name: "wal_writer_multiple_entries", run: testWALWriterMultipleEntries},
+		{name: "wal_writer_wrap_around", run: testWALWriterWrapAround},
+		{name: "wal_writer_full", run: testWALWriterFull},
+		{name: "wal_writer_advance_tail_frees_space", run: testWALWriterAdvanceTailFreesSpace},
+		{name: "wal_writer_fill_no_flusher", run: testWALWriterFillNoFlusher},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.run(t)
+		})
+	}
+}
+
+// createTestWAL creates a temp file with a WAL region at the given offset and size.
+func createTestWAL(t *testing.T, walOffset, walSize uint64) (*os.File, func()) {
+	t.Helper()
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test.blockvol")
+	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
+	if err != nil {
+		t.Fatalf("create temp file: %v", err)
+	}
+	// Extend file to cover superblock + WAL region.
+	totalSize := int64(walOffset + walSize)
+	if err := fd.Truncate(totalSize); err != nil {
+		fd.Close()
+		t.Fatalf("truncate: %v", err)
+	}
+	return fd, func() { fd.Close() }
+}
+
+// testWALWriterAppendReadBack appends one entry and decodes it back from the
+// file to verify the on-disk encoding round-trips.
+func testWALWriterAppendReadBack(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	walSize := uint64(64 * 1024) // 64KB WAL
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	data := []byte("hello WAL writer test!")
+	// Pad to block size for a realistic entry.
+	padded := make([]byte, 4096)
+	copy(padded, data)
+
+	entry := &WALEntry{
+		LSN:    1,
+		Type:   EntryTypeWrite,
+		LBA:    0,
+		Length: uint32(len(padded)),
+		Data:   padded,
+	}
+
+	off, err := w.Append(entry)
+	if err != nil {
+		t.Fatalf("Append: %v", err)
+	}
+	if off != 0 {
+		t.Errorf("first entry offset = %d, want 0", off)
+	}
+
+	// Read back from file and decode.
+	buf := make([]byte, entry.EntrySize)
+	if _, err := fd.ReadAt(buf, int64(walOffset+off)); err != nil {
+		t.Fatalf("ReadAt: %v", err)
+	}
+	decoded, err := DecodeWALEntry(buf)
+	if err != nil {
+		t.Fatalf("DecodeWALEntry: %v", err)
+	}
+	if decoded.LSN != 1 {
+		t.Errorf("decoded LSN = %d, want 1", decoded.LSN)
+	}
+}
+
+// testWALWriterMultipleEntries appends several entries and checks the
+// physical head advanced by exactly N * entry size.
+func testWALWriterMultipleEntries(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	walSize := uint64(64 * 1024)
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	for i := uint64(1); i <= 5; i++ {
+		entry := &WALEntry{
+			LSN:    i,
+			Type:   EntryTypeWrite,
+			LBA:    i * 10,
+			Length: 4096,
+			Data:   make([]byte, 4096),
+		}
+		_, err := w.Append(entry)
+		if err != nil {
+			t.Fatalf("Append entry %d: %v", i, err)
+		}
+	}
+
+	expectedHead := uint64(5 * (walEntryHeaderSize + 4096))
+	if w.Head() != expectedHead {
+		t.Errorf("head = %d, want %d", w.Head(), expectedHead)
+	}
+}
+
+// testWALWriterWrapAround verifies an entry that no longer fits at the end of
+// the ring is written at physical offset 0 once the tail has been advanced.
+func testWALWriterWrapAround(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	// Small WAL: 2.5x entry size. After writing 2 entries (head at 2*entrySize),
+	// remaining (0.5*entrySize) is too small for entry 3, so it wraps to offset 0.
+	// Tail must be advanced past both entries so there's free space after wrap.
+	entrySize := uint64(walEntryHeaderSize + 4096)
+	walSize := entrySize*2 + entrySize/2
+
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	// Write 2 entries to fill most of the WAL.
+	for i := uint64(1); i <= 2; i++ {
+		entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)}
+		if _, err := w.Append(entry); err != nil {
+			t.Fatalf("Append entry %d: %v", i, err)
+		}
+	}
+
+	// Advance tail past both entries (simulates flusher flushed them).
+	w.AdvanceTail(entrySize * 2)
+
+	// Write a 3rd entry -- should wrap around.
+	entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 3, Length: 4096, Data: make([]byte, 4096)}
+	off, err := w.Append(entry3)
+	if err != nil {
+		t.Fatalf("Append entry 3 (wrap): %v", err)
+	}
+
+	// After wrap, entry should be written at offset 0.
+	if off != 0 {
+		t.Errorf("wrapped entry offset = %d, want 0", off)
+	}
+}
+
+// testWALWriterFull verifies that appending to an exactly-full WAL reports
+// ErrWALFull rather than overwriting unflushed entries.
+func testWALWriterFull(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	entrySize := uint64(walEntryHeaderSize + 4096)
+	walSize := entrySize * 2 // fits exactly 2 entries
+
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	// Fill the WAL with 2 entries (exact fit).
+	for i := uint64(1); i <= 2; i++ {
+		entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)}
+		if _, err := w.Append(entry); err != nil {
+			t.Fatalf("Append entry %d: %v", i, err)
+		}
+	}
+
+	// Third entry must fail with the specific sentinel -- tail hasn't moved,
+	// so there is no free space. Checking errors.Is (not just err != nil)
+	// ensures an unrelated I/O error cannot masquerade as a pass.
+	entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)}
+	_, err := w.Append(entry3)
+	if !errors.Is(err, ErrWALFull) {
+		t.Fatalf("Append on full WAL: got %v, want ErrWALFull", err)
+	}
+}
+
+// testWALWriterAdvanceTailFreesSpace verifies AdvanceTail makes room for a
+// previously rejected append.
+func testWALWriterAdvanceTailFreesSpace(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	entrySize := uint64(walEntryHeaderSize + 4096)
+	walSize := entrySize * 2
+
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	// Fill the WAL with 2 entries.
+	for i := uint64(1); i <= 2; i++ {
+		entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)}
+		if _, err := w.Append(entry); err != nil {
+			t.Fatalf("Append entry %d: %v", i, err)
+		}
+	}
+
+	// WAL is full. Third entry must fail with ErrWALFull specifically.
+	entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)}
+	if _, err := w.Append(entry3); !errors.Is(err, ErrWALFull) {
+		t.Fatalf("Append before AdvanceTail: got %v, want ErrWALFull", err)
+	}
+
+	// Advance tail to free space for 1 entry.
+	w.AdvanceTail(entrySize)
+
+	// Now entry 3 should succeed.
+	if _, err := w.Append(entry3); err != nil {
+		t.Fatalf("Append after AdvanceTail: %v", err)
+	}
+}
+
+// testWALWriterFillNoFlusher is the QA-001 regression: fill the WAL without a
+// flusher (tail stays at 0). After wrap, head=0 and tail=0 must NOT be
+// treated as "empty".
+func testWALWriterFillNoFlusher(t *testing.T) {
+	walOffset := uint64(SuperblockSize)
+	entrySize := uint64(walEntryHeaderSize + 4096)
+	walSize := entrySize * 10 // room for ~10 entries
+
+	fd, cleanup := createTestWAL(t, walOffset, walSize)
+	defer cleanup()
+
+	w := NewWALWriter(fd, walOffset, walSize, 0, 0)
+
+	// Fill the WAL completely -- tail never moves (no flusher).
+	written := 0
+	var lastErr error
+	for i := uint64(1); ; i++ {
+		entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)}
+		_, err := w.Append(entry)
+		if err != nil {
+			// Should eventually get ErrWALFull, NOT wrap and overwrite.
+			lastErr = err
+			break
+		}
+		written++
+		if written > 20 {
+			t.Fatalf("wrote %d entries to a 10-entry WAL without ErrWALFull -- wrap overwrote live entries (QA-001)", written)
+		}
+	}
+
+	// The loop must have terminated with the specific sentinel, not an
+	// arbitrary error.
+	if !errors.Is(lastErr, ErrWALFull) {
+		t.Fatalf("fill loop ended with %v, want ErrWALFull", lastErr)
+	}
+	if written == 0 {
+		t.Fatal("should have written at least 1 entry")
+	}
+	t.Logf("correctly wrote %d entries then got ErrWALFull", written)
+}