Browse Source
feat: add BlockVol engine and iSCSI target (Phase 1 + Phase 2)
feat: add BlockVol engine and iSCSI target (Phase 1 + Phase 2)
Phase 1: Extent-mapped block storage engine with WAL, crash recovery, dirty map, flusher, and group commit. 174 tests, zero SeaweedFS imports. Phase 2: Pure Go iSCSI target (RFC 7143) with PDU codec, login negotiation, SendTargets discovery, 12 SCSI opcodes, Data-In/Out/R2T sequencing, session management, and standalone iscsi-target binary. 164 tests. IQN->BlockDevice binding via DeviceLookup interface. Total: 338 tests, 14.6K lines. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Branch: feature/sw-block
38 changed files with 14618 additions and 0 deletions
-
371weed/storage/blockvol/blockvol.go
-
3903weed/storage/blockvol/blockvol_qa_test.go
-
685weed/storage/blockvol/blockvol_test.go
-
110weed/storage/blockvol/dirty_map.go
-
153weed/storage/blockvol/dirty_map_test.go
-
239weed/storage/blockvol/flusher.go
-
261weed/storage/blockvol/flusher_test.go
-
167weed/storage/blockvol/group_commit.go
-
287weed/storage/blockvol/group_commit_test.go
-
141weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go
-
216weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh
-
207weed/storage/blockvol/iscsi/dataio.go
-
410weed/storage/blockvol/iscsi/dataio_test.go
-
68weed/storage/blockvol/iscsi/discovery.go
-
213weed/storage/blockvol/iscsi/discovery_test.go
-
412weed/storage/blockvol/iscsi/integration_test.go
-
396weed/storage/blockvol/iscsi/login.go
-
444weed/storage/blockvol/iscsi/login_test.go
-
183weed/storage/blockvol/iscsi/params.go
-
357weed/storage/blockvol/iscsi/params_test.go
-
447weed/storage/blockvol/iscsi/pdu.go
-
559weed/storage/blockvol/iscsi/pdu_test.go
-
463weed/storage/blockvol/iscsi/scsi.go
-
692weed/storage/blockvol/iscsi/scsi_test.go
-
421weed/storage/blockvol/iscsi/session.go
-
364weed/storage/blockvol/iscsi/session_test.go
-
216weed/storage/blockvol/iscsi/target.go
-
259weed/storage/blockvol/iscsi/target_test.go
-
38weed/storage/blockvol/lba.go
-
95weed/storage/blockvol/lba_test.go
-
157weed/storage/blockvol/recovery.go
-
416weed/storage/blockvol/recovery_test.go
-
249weed/storage/blockvol/superblock.go
-
146weed/storage/blockvol/superblock_test.go
-
153weed/storage/blockvol/wal_entry.go
-
284weed/storage/blockvol/wal_entry_test.go
-
192weed/storage/blockvol/wal_writer.go
-
244weed/storage/blockvol/wal_writer_test.go
@ -0,0 +1,371 @@ |
|||
// Package blockvol implements a block volume storage engine with WAL,
|
|||
// dirty map, group commit, and crash recovery. It operates on a single
|
|||
// file and has no dependency on SeaweedFS internals.
|
|||
package blockvol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
"os" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
) |
|||
|
|||
// CreateOptions configures a new block volume.
//
// VolumeSize is the only required field; zero-valued optional fields are
// replaced with the stated defaults (presumably inside NewSuperblock —
// TODO confirm against superblock.go).
type CreateOptions struct {
	VolumeSize  uint64 // required, logical size in bytes
	ExtentSize  uint32 // default 64KB
	BlockSize   uint32 // default 4KB
	WALSize     uint64 // default 64MB
	Replication string // default "000"
}
|||
|
|||
// BlockVol is the core block volume engine. It owns a single backing file
// laid out as superblock | WAL region | extent region.
type BlockVol struct {
	mu          sync.RWMutex    // NOTE(review): declared but not used in this chunk — confirm usage elsewhere
	fd          *os.File        // backing file; all reads/writes go through pread/pwrite (ReadAt/WriteAt)
	path        string          // file path, returned by Path()
	super       Superblock      // in-memory copy of the on-disk superblock
	wal         *WALWriter      // append-only write-ahead log inside the WAL region
	dirtyMap    *DirtyMap       // LBA -> WAL entry offset for not-yet-flushed blocks
	groupCommit *GroupCommitter // batches fsyncs for SyncCache
	flusher     *Flusher        // background goroutine draining the dirty map to the extent region
	nextLSN     atomic.Uint64   // next log sequence number to assign
	healthy     atomic.Bool     // cleared when the group committer reports degradation
}
|||
|
|||
// CreateBlockVol creates a new block volume file at path.
//
// The file is created exclusively (O_EXCL), sized to hold the superblock,
// the WAL region, and the full extent region, and the superblock is written
// and fsynced before the engine goroutines (group committer, flusher) start.
// On any error after file creation the partial file is removed.
func CreateBlockVol(path string, opts CreateOptions) (*BlockVol, error) {
	if opts.VolumeSize == 0 {
		return nil, ErrInvalidVolumeSize
	}

	sb, err := NewSuperblock(opts.VolumeSize, opts)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create superblock: %w", err)
	}
	sb.CreatedAt = uint64(time.Now().Unix())

	// Extent region starts after superblock + WAL.
	extentStart := sb.WALOffset + sb.WALSize
	totalFileSize := extentStart + opts.VolumeSize

	// O_EXCL: refuse to clobber an existing volume file.
	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create file: %w", err)
	}

	// Truncate up-front so the extent region reads back as zeros.
	if err := fd.Truncate(int64(totalFileSize)); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: truncate: %w", err)
	}

	if _, err := sb.WriteTo(fd); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: write superblock: %w", err)
	}

	// Make the superblock durable before handing the volume to callers.
	if err := fd.Sync(); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: sync: %w", err)
	}

	// Fresh volume: empty dirty map, WAL head/tail both at 0, LSNs start at 1.
	dm := NewDirtyMap()
	wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, 0, 0)
	v := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      wal,
		dirtyMap: dm,
	}
	v.nextLSN.Store(1)
	v.healthy.Store(true)
	v.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { v.healthy.Store(false) },
	})
	go v.groupCommit.Run()
	v.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &v.super,
		WAL:      wal,
		DirtyMap: dm,
	})
	go v.flusher.Run()
	return v, nil
}
|||
|
|||
// OpenBlockVol opens an existing block volume file and runs crash recovery.
//
// It validates the on-disk superblock, replays the WAL into a fresh dirty
// map via RecoverWAL, derives the next LSN from the higher of the
// checkpoint LSN and the highest replayed LSN, and then starts the same
// background goroutines as CreateBlockVol.
func OpenBlockVol(path string) (*BlockVol, error) {
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: open file: %w", err)
	}

	sb, err := ReadSuperblock(fd)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: read superblock: %w", err)
	}
	if err := sb.Validate(); err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: validate superblock: %w", err)
	}

	dirtyMap := NewDirtyMap()

	// Run WAL recovery: replay entries from tail to head.
	result, err := RecoverWAL(fd, &sb, dirtyMap)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: recovery: %w", err)
	}

	// nextLSN must be strictly greater than anything already in the log,
	// whether that came from a checkpoint or from replayed entries.
	nextLSN := sb.WALCheckpointLSN + 1
	if result.HighestLSN >= nextLSN {
		nextLSN = result.HighestLSN + 1
	}

	// Resume the WAL at the persisted head/tail positions.
	wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, sb.WALHead, sb.WALTail)
	v := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      wal,
		dirtyMap: dirtyMap,
	}
	v.nextLSN.Store(nextLSN)
	v.healthy.Store(true)
	v.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { v.healthy.Store(false) },
	})
	go v.groupCommit.Run()
	v.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &v.super,
		WAL:      wal,
		DirtyMap: dirtyMap,
	})
	go v.flusher.Run()
	return v, nil
}
|||
|
|||
// WriteLBA writes data at the given logical block address.
|
|||
// Data length must be a multiple of BlockSize.
|
|||
func (v *BlockVol) WriteLBA(lba uint64, data []byte) error { |
|||
if err := ValidateWrite(lba, uint32(len(data)), v.super.VolumeSize, v.super.BlockSize); err != nil { |
|||
return err |
|||
} |
|||
|
|||
lsn := v.nextLSN.Add(1) - 1 |
|||
entry := &WALEntry{ |
|||
LSN: lsn, |
|||
Epoch: 0, // Phase 1: no fencing
|
|||
Type: EntryTypeWrite, |
|||
LBA: lba, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
walOff, err := v.wal.Append(entry) |
|||
if err != nil { |
|||
return fmt.Errorf("blockvol: WriteLBA: %w", err) |
|||
} |
|||
|
|||
// Update dirty map: one entry per block written.
|
|||
blocks := uint32(len(data)) / v.super.BlockSize |
|||
for i := uint32(0); i < blocks; i++ { |
|||
blockOff := walOff // all blocks in this entry share the same WAL offset
|
|||
v.dirtyMap.Put(lba+uint64(i), blockOff, lsn, v.super.BlockSize) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ReadLBA reads data at the given logical block address.
|
|||
// length is in bytes and must be a multiple of BlockSize.
|
|||
func (v *BlockVol) ReadLBA(lba uint64, length uint32) ([]byte, error) { |
|||
if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
blocks := length / v.super.BlockSize |
|||
result := make([]byte, length) |
|||
|
|||
for i := uint32(0); i < blocks; i++ { |
|||
blockLBA := lba + uint64(i) |
|||
blockData, err := v.readOneBlock(blockLBA) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("blockvol: ReadLBA block %d: %w", blockLBA, err) |
|||
} |
|||
copy(result[i*v.super.BlockSize:], blockData) |
|||
} |
|||
|
|||
return result, nil |
|||
} |
|||
|
|||
// readOneBlock reads a single block, checking dirty map first, then extent.
|
|||
func (v *BlockVol) readOneBlock(lba uint64) ([]byte, error) { |
|||
walOff, _, _, ok := v.dirtyMap.Get(lba) |
|||
if ok { |
|||
return v.readBlockFromWAL(walOff, lba) |
|||
} |
|||
return v.readBlockFromExtent(lba) |
|||
} |
|||
|
|||
// maxWALEntryDataLen caps the data length we trust from a WAL entry header.
// Anything larger than the WAL region itself is corrupt; this absolute
// ceiling additionally guards against huge allocations even when the WAL
// region is configured larger than 256MB.
const maxWALEntryDataLen = 256 * 1024 * 1024 // 256MB absolute ceiling
|
|||
|
|||
// readBlockFromWAL reads a block's data from its WAL entry.
//
// walOff is the entry's offset relative to the WAL region; lba selects
// which block inside the (possibly multi-block) entry to return. The
// Length field is validated against the WAL size and maxWALEntryDataLen
// before any allocation so a corrupt header cannot trigger an OOM.
func (v *BlockVol) readBlockFromWAL(walOff uint64, lba uint64) ([]byte, error) {
	// Read the WAL entry header to get the full entry size.
	headerBuf := make([]byte, walEntryHeaderSize)
	absOff := int64(v.super.WALOffset + walOff)
	if _, err := v.fd.ReadAt(headerBuf, absOff); err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: read header at %d: %w", absOff, err)
	}

	// Check entry type first — TRIM has no data payload, so Length is
	// metadata (trim extent), not a data size to allocate.
	entryType := headerBuf[16] // Type is at offset LSN(8) + Epoch(8) = 16
	if entryType == EntryTypeTrim {
		// TRIM entry: return zeros regardless of Length field.
		return make([]byte, v.super.BlockSize), nil
	}
	if entryType != EntryTypeWrite {
		return nil, fmt.Errorf("readBlockFromWAL: expected WRITE or TRIM entry, got type 0x%02x", entryType)
	}

	// Parse and validate the data Length field before allocating (WRITE only).
	dataLen := v.parseDataLength(headerBuf)
	if uint64(dataLen) > v.super.WALSize || uint64(dataLen) > maxWALEntryDataLen {
		return nil, fmt.Errorf("readBlockFromWAL: corrupt entry length %d exceeds WAL size %d", dataLen, v.super.WALSize)
	}

	// Re-read header + payload in one pread, then decode (CRC checks etc.
	// presumably happen inside DecodeWALEntry — confirm in wal_entry.go).
	entryLen := walEntryHeaderSize + int(dataLen)
	fullBuf := make([]byte, entryLen)
	if _, err := v.fd.ReadAt(fullBuf, absOff); err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: read entry at %d: %w", absOff, err)
	}

	entry, err := DecodeWALEntry(fullBuf)
	if err != nil {
		return nil, fmt.Errorf("readBlockFromWAL: decode: %w", err)
	}

	// Find the block within the entry's data.
	blockOffset := (lba - entry.LBA) * uint64(v.super.BlockSize)
	if blockOffset+uint64(v.super.BlockSize) > uint64(len(entry.Data)) {
		return nil, fmt.Errorf("readBlockFromWAL: block offset %d out of range for entry data len %d", blockOffset, len(entry.Data))
	}

	// Return a private copy so callers cannot alias the decoded entry.
	block := make([]byte, v.super.BlockSize)
	copy(block, entry.Data[blockOffset:blockOffset+uint64(v.super.BlockSize)])
	return block, nil
}
|||
|
|||
// parseDataLength extracts the Length field from a WAL entry header buffer.
|
|||
func (v *BlockVol) parseDataLength(headerBuf []byte) uint32 { |
|||
// Length is at offset: LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8) = 26
|
|||
return binary.LittleEndian.Uint32(headerBuf[26:]) |
|||
} |
|||
|
|||
// readBlockFromExtent reads a block directly from the extent region.
|
|||
func (v *BlockVol) readBlockFromExtent(lba uint64) ([]byte, error) { |
|||
extentStart := v.super.WALOffset + v.super.WALSize |
|||
byteOffset := extentStart + lba*uint64(v.super.BlockSize) |
|||
|
|||
block := make([]byte, v.super.BlockSize) |
|||
if _, err := v.fd.ReadAt(block, int64(byteOffset)); err != nil { |
|||
return nil, fmt.Errorf("readBlockFromExtent: pread at %d: %w", byteOffset, err) |
|||
} |
|||
return block, nil |
|||
} |
|||
|
|||
// Trim marks blocks as deallocated. Subsequent reads return zeros.
|
|||
// The trim is recorded in the WAL with a Length field so the flusher
|
|||
// can zero the extent region and recovery can replay the trim.
|
|||
func (v *BlockVol) Trim(lba uint64, length uint32) error { |
|||
if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { |
|||
return err |
|||
} |
|||
|
|||
lsn := v.nextLSN.Add(1) - 1 |
|||
entry := &WALEntry{ |
|||
LSN: lsn, |
|||
Epoch: 0, |
|||
Type: EntryTypeTrim, |
|||
LBA: lba, |
|||
Length: length, |
|||
} |
|||
|
|||
walOff, err := v.wal.Append(entry) |
|||
if err != nil { |
|||
return fmt.Errorf("blockvol: Trim: %w", err) |
|||
} |
|||
|
|||
// Update dirty map: mark each trimmed block so the flusher sees it.
|
|||
// readOneBlock checks entry type and returns zeros for TRIM entries.
|
|||
blocks := length / v.super.BlockSize |
|||
for i := uint32(0); i < blocks; i++ { |
|||
v.dirtyMap.Put(lba+uint64(i), walOff, lsn, v.super.BlockSize) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Path returns the file path of the block volume, as given to
// CreateBlockVol or OpenBlockVol.
func (v *BlockVol) Path() string {
	return v.path
}
|||
|
|||
// Info returns volume metadata.
|
|||
func (v *BlockVol) Info() VolumeInfo { |
|||
return VolumeInfo{ |
|||
VolumeSize: v.super.VolumeSize, |
|||
BlockSize: v.super.BlockSize, |
|||
ExtentSize: v.super.ExtentSize, |
|||
WALSize: v.super.WALSize, |
|||
Healthy: v.healthy.Load(), |
|||
} |
|||
} |
|||
|
|||
// VolumeInfo contains read-only volume metadata, returned by Info.
type VolumeInfo struct {
	VolumeSize uint64 // logical volume size in bytes
	BlockSize  uint32 // block size in bytes
	ExtentSize uint32 // extent size in bytes
	WALSize    uint64 // WAL region size in bytes
	Healthy    bool   // false once the group committer reports degradation
}
|||
|
|||
// SyncCache ensures all previously written WAL entries are durable on disk.
// It submits a sync request to the group committer, which batches fsyncs
// from concurrent callers into a single fd.Sync. Blocks until the fsync
// covering this request completes (presumably — confirm in group_commit.go).
func (v *BlockVol) SyncCache() error {
	return v.groupCommit.Submit()
}
|||
|
|||
// Close shuts down the block volume and closes the file.
|
|||
// Shutdown order: group committer → stop flusher goroutine → final flush → close fd.
|
|||
func (v *BlockVol) Close() error { |
|||
if v.groupCommit != nil { |
|||
v.groupCommit.Stop() |
|||
} |
|||
var flushErr error |
|||
if v.flusher != nil { |
|||
v.flusher.Stop() // stop background goroutine first (no concurrent flush)
|
|||
flushErr = v.flusher.FlushOnce() // then do final flush safely
|
|||
} |
|||
closeErr := v.fd.Close() |
|||
if flushErr != nil { |
|||
return flushErr |
|||
} |
|||
return closeErr |
|||
} |
|||
3903
weed/storage/blockvol/blockvol_qa_test.go
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
@ -0,0 +1,685 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"errors" |
|||
"os" |
|||
"path/filepath" |
|||
"testing" |
|||
) |
|||
|
|||
// TestBlockVol is the table-driven entry point for the blockvol engine
// tests; each row runs as an isolated subtest.
func TestBlockVol(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "write_then_read", run: testWriteThenRead},
		{name: "overwrite_read_latest", run: testOverwriteReadLatest},
		{name: "write_no_sync_not_durable", run: testWriteNoSyncNotDurable},
		{name: "write_sync_durable", run: testWriteSyncDurable},
		{name: "write_multiple_sync", run: testWriteMultipleSync},
		{name: "read_unflushed", run: testReadUnflushed},
		{name: "read_flushed", run: testReadFlushed},
		{name: "read_mixed_dirty_clean", run: testReadMixedDirtyClean},
		{name: "wal_read_corrupt_length", run: testWALReadCorruptLength},
		{name: "open_invalid_superblock", run: testOpenInvalidSuperblock},
		{name: "trim_large_length_read_returns_zero", run: testTrimLargeLengthReadReturnsZero},
		// Task 1.10: Lifecycle tests.
		{name: "lifecycle_create_close_reopen", run: testLifecycleCreateCloseReopen},
		{name: "lifecycle_close_flushes_dirty", run: testLifecycleCloseFlushes},
		{name: "lifecycle_double_close", run: testLifecycleDoubleClose},
		{name: "lifecycle_info", run: testLifecycleInfo},
		{name: "lifecycle_write_sync_close_reopen", run: testLifecycleWriteSyncCloseReopen},
		// Task 1.11: Crash stress test.
		{name: "crash_stress_100", run: testCrashStress100},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// createTestVol creates a 1MB volume (4KB blocks, 256KB WAL) in a per-test
// temp dir and fails the test on error. Callers own the Close.
func createTestVol(t *testing.T) *BlockVol {
	t.Helper()
	dir := t.TempDir()
	path := filepath.Join(dir, "test.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024, // 1MB
		BlockSize:  4096,
		WALSize:    256 * 1024, // 256KB WAL
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	return v
}
|||
|
|||
// makeBlock returns one 4KB test block with every byte set to fill.
func makeBlock(fill byte) []byte {
	return bytes.Repeat([]byte{fill}, 4096)
}
|||
|
|||
// testWriteThenRead verifies the simplest round trip: one block written,
// then read back (from the dirty map / WAL path, since nothing is flushed).
func testWriteThenRead(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("read data does not match written data")
	}
}
|||
|
|||
// testOverwriteReadLatest verifies that overwriting the same LBA makes the
// newer WAL entry win on read.
func testOverwriteReadLatest(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(A): %v", err)
	}
	if err := v.WriteLBA(0, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(B): %v", err)
	}

	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, makeBlock('B')) {
		t.Error("read should return latest write ('B'), not 'A'")
	}
}
|||
|
|||
// testWriteNoSyncNotDurable simulates a crash (fd closed with no sync)
// after an unsynced write and only asserts that reopen succeeds; the data
// is explicitly allowed to be lost.
func testWriteNoSyncNotDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()

	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Simulate crash: close fd without sync.
	v.fd.Close()

	// Reopen -- without recovery (Phase 1.9), data MAY be lost.
	// This test just verifies we can reopen without error.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol after crash: %v", err)
	}
	defer v2.Close()
	// Data may or may not be present -- both are correct without SyncCache.
}
|||
|
|||
// testWriteSyncDurable verifies that a synced WAL entry survives a
// reopen. The WAL is synced and the superblock head manually persisted
// (predating the group committer), then the entry is replayed by hand into
// the reopened volume's dirty map before reading it back.
func testWriteSyncDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()

	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Sync WAL (manual fsync for now, group commit in Task 1.7).
	if err := v.wal.Sync(); err != nil {
		t.Fatalf("Sync: %v", err)
	}

	// Update superblock WALHead so reopen knows where entries are.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALCheckpointLSN = 0
	if _, err := v.fd.Seek(0, 0); err != nil {
		t.Fatalf("Seek: %v", err)
	}
	if _, err := v.super.WriteTo(v.fd); err != nil {
		t.Fatalf("WriteTo: %v", err)
	}
	v.fd.Sync()
	v.fd.Close()

	// Reopen and manually replay WAL to verify data is durable.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Manually replay: read WAL entry and populate dirty map.
	replayBuf := make([]byte, v2.super.WALHead)
	if _, err := v2.fd.ReadAt(replayBuf, int64(v2.super.WALOffset)); err != nil {
		t.Fatalf("read WAL for replay: %v", err)
	}
	entry, err := DecodeWALEntry(replayBuf)
	if err != nil {
		t.Fatalf("decode WAL entry: %v", err)
	}
	// walOff 0 is correct here: this is the first (only) entry in the WAL.
	blocks := entry.Length / v2.super.BlockSize
	for i := uint32(0); i < blocks; i++ {
		v2.dirtyMap.Put(entry.LBA+uint64(i), 0, entry.LSN, v2.super.BlockSize)
	}

	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not durable after sync + reopen")
	}
}
|||
|
|||
// testWriteMultipleSync writes ten distinct blocks and reads each back,
// exercising multiple live dirty-map entries at once.
func testWriteMultipleSync(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write 10 blocks with different data.
	for i := uint64(0); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Read all back.
	for i := uint64(0); i < 10; i++ {
		got, err := v.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch", i)
		}
	}
}
|||
|
|||
// testReadUnflushed verifies a read served from the dirty map / WAL before
// the flusher has moved the block into the extent region.
func testReadUnflushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('X')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Read before flusher runs -- should come from dirty map / WAL.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("unflushed read: data mismatch")
	}
}
|||
|
|||
// testReadFlushed verifies the extent-region read path by flushing a block
// by hand (direct pwrite + dirty-map delete) instead of waiting on the
// background flusher.
func testReadFlushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('F')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Manually flush: copy WAL data to extent region, clear dirty map.
	extentStart := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(data, int64(extentStart)); err != nil {
		t.Fatalf("manual flush write: %v", err)
	}
	v.dirtyMap.Delete(0)

	// Read should now come from extent region.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after flush: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("flushed read: data mismatch")
	}
}
|||
|
|||
// testReadMixedDirtyClean verifies all three read sources side by side:
// a manually flushed block (extent), still-dirty blocks (WAL), and
// never-written blocks (zeros from the truncated extent region).
func testReadMixedDirtyClean(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write blocks 0, 2, 4 (dirty).
	for _, lba := range []uint64{0, 2, 4} {
		if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}

	// Manually flush block 0 to extent, remove from dirty map.
	extentStart := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(makeBlock('A'), int64(extentStart)); err != nil {
		t.Fatalf("manual flush: %v", err)
	}
	v.dirtyMap.Delete(0)

	// Block 0: from extent (flushed)
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0): %v", err)
	}
	if !bytes.Equal(got, makeBlock('A')) {
		t.Error("block 0 (flushed) mismatch")
	}

	// Block 2: from dirty map (WAL)
	got, err = v.ReadLBA(2, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(2): %v", err)
	}
	if !bytes.Equal(got, makeBlock('C')) { // 'A'+2 = 'C'
		t.Error("block 2 (dirty) mismatch")
	}

	// Block 4: from dirty map (WAL)
	got, err = v.ReadLBA(4, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(4): %v", err)
	}
	if !bytes.Equal(got, makeBlock('E')) { // 'A'+4 = 'E'
		t.Error("block 4 (dirty) mismatch")
	}

	// Blocks 1, 3, 5: never written, should be zeros (from extent region).
	for _, lba := range []uint64{1, 3, 5} {
		got, err = v.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, make([]byte, 4096)) {
			t.Errorf("block %d (unwritten) should be zeros", lba)
		}
	}
}
|||
|
|||
// testWALReadCorruptLength corrupts the on-disk Length field of a WAL
// entry and verifies ReadLBA returns an error instead of allocating ~1GB
// or panicking (guards the maxWALEntryDataLen / WALSize validation path).
func testWALReadCorruptLength(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write a valid block.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Get the WAL offset from dirty map.
	walOff, _, _, ok := v.dirtyMap.Get(0)
	if !ok {
		t.Fatal("block 0 not in dirty map")
	}

	// Corrupt the Length field in the WAL entry on disk.
	// Length is at header offset 26 (LSN=8 + Epoch=8 + Type=1 + Flags=1 + LBA=8).
	absOff := int64(v.super.WALOffset + walOff)
	lengthOff := absOff + 26
	var hugeLenBuf [4]byte
	binary.LittleEndian.PutUint32(hugeLenBuf[:], 999999999) // ~1GB
	if _, err := v.fd.WriteAt(hugeLenBuf[:], lengthOff); err != nil {
		t.Fatalf("corrupt length: %v", err)
	}

	// ReadLBA should detect the corrupt length and error (not panic/OOM).
	_, err := v.ReadLBA(0, 4096)
	if err == nil {
		t.Error("expected error reading corrupt WAL entry, got nil")
	}
}
|||
|
|||
// testOpenInvalidSuperblock zeroes the BlockSize field on disk and
// verifies OpenBlockVol rejects the volume with ErrInvalidSuperblock.
func testOpenInvalidSuperblock(t *testing.T) {
	dir := t.TempDir()

	// Create a valid volume, then corrupt BlockSize to 0.
	path := filepath.Join(dir, "corrupt.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{VolumeSize: 1024 * 1024})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	v.Close()

	// Corrupt BlockSize (at superblock offset 36: Magic=4 + Version=2 + Flags=2 + UUID=16 + VolumeSize=8 + ExtentSize=4 = 36).
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("open for corrupt: %v", err)
	}
	var zeroBuf [4]byte
	if _, err := fd.WriteAt(zeroBuf[:], 36); err != nil {
		fd.Close()
		t.Fatalf("corrupt blocksize: %v", err)
	}
	fd.Close()

	// OpenBlockVol should reject the corrupt superblock.
	_, err = OpenBlockVol(path)
	if err == nil {
		t.Fatal("expected error opening volume with BlockSize=0")
	}
	if !errors.Is(err, ErrInvalidSuperblock) {
		t.Errorf("expected ErrInvalidSuperblock, got: %v", err)
	}
}
|||
|
|||
// --- Task 1.10: Lifecycle tests ---
|
|||
|
|||
// testLifecycleCreateCloseReopen checks the full clean-shutdown path:
// create → write → SyncCache → Close (final flush) → reopen → read.
func testLifecycleCreateCloseReopen(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "lifecycle.blockvol")

	// Create, write, close.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	data := makeBlock('L')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	// Reopen and verify data survived.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not durable after create→write→sync→close→reopen")
	}
}
|||
|
|||
// testLifecycleCloseFlushes verifies that Close's final flush drains a
// multi-block dirty map and the data is readable after reopen.
func testLifecycleCloseFlushes(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "flush.blockvol")

	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	// Write several blocks.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Close does a final flush — dirty map should be drained.
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	// Reopen and verify.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	for i := uint64(0); i < 5; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after close+reopen", i)
		}
	}
}
|||
|
|||
// testLifecycleDoubleClose verifies that calling Close twice does not
// panic. The second close may return an error (fd already closed) but the
// group committer and flusher stops must be idempotent.
func testLifecycleDoubleClose(t *testing.T) {
	v := createTestVol(t)
	if err := v.Close(); err != nil {
		t.Fatalf("first Close: %v", err)
	}
	// Second close should not panic (group committer + flusher are idempotent).
	// The fd.Close() will return an error but should not panic.
	_ = v.Close()
}
|||
|
|||
// testLifecycleInfo checks that Info mirrors the CreateOptions used by
// createTestVol and that a fresh volume reports Healthy.
func testLifecycleInfo(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	info := v.Info()
	if info.VolumeSize != 1*1024*1024 {
		t.Errorf("VolumeSize = %d, want %d", info.VolumeSize, 1*1024*1024)
	}
	if info.BlockSize != 4096 {
		t.Errorf("BlockSize = %d, want 4096", info.BlockSize)
	}
	if !info.Healthy {
		t.Error("Healthy = false, want true")
	}
}
|||
|
|||
// testLifecycleWriteSyncCloseReopen runs two full write/sync/close cycles,
// including an overwrite in the second session, and verifies the final
// state after a third open: block 0 holds the overwrite, block 1 the
// second-session write.
func testLifecycleWriteSyncCloseReopen(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "wsco.blockvol")

	// Cycle: create → write → sync → close → reopen → write → sync → close → verify.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	if err := v.WriteLBA(0, makeBlock('X')); err != nil {
		t.Fatalf("WriteLBA round 1: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache round 1: %v", err)
	}
	v.Close()

	// Reopen, write more.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol round 2: %v", err)
	}

	if err := v2.WriteLBA(1, makeBlock('Y')); err != nil {
		t.Fatalf("WriteLBA round 2: %v", err)
	}
	// Overwrite block 0.
	if err := v2.WriteLBA(0, makeBlock('Z')); err != nil {
		t.Fatalf("WriteLBA overwrite: %v", err)
	}
	if err := v2.SyncCache(); err != nil {
		t.Fatalf("SyncCache round 2: %v", err)
	}
	v2.Close()

	// Final reopen — verify.
	v3, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol round 3: %v", err)
	}
	defer v3.Close()

	got0, _ := v3.ReadLBA(0, 4096)
	if !bytes.Equal(got0, makeBlock('Z')) {
		t.Error("block 0: expected 'Z' (overwritten)")
	}
	got1, _ := v3.ReadLBA(1, 4096)
	if !bytes.Equal(got1, makeBlock('Y')) {
		t.Error("block 1: expected 'Y'")
	}
}
|||
|
|||
// --- Task 1.11: Crash stress test ---

// testCrashStress100 runs 100 simulated crash/recovery cycles against a
// single volume file. Each iteration applies a deterministic mix of
// writes/overwrites/trims, syncs the WAL, simulates a hard crash (fd closed
// without clean shutdown), reopens with recovery, and verifies every block
// ever touched against an in-memory oracle.
func testCrashStress100(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "stress.blockvol")

	const (
		volumeSize = 256 * 1024 // 256KB volume (64 blocks of 4KB)
		blockSize  = 4096
		walSize    = 64 * 1024 // 64KB WAL
		maxLBA     = volumeSize / blockSize
		iterations = 100
	)

	// Oracle: tracks the expected state of each block.
	oracle := make(map[uint64]byte) // lba → fill byte (0 = zeros/trimmed)

	// Create initial volume.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: volumeSize,
		BlockSize:  blockSize,
		WALSize:    walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	for iter := 0; iter < iterations; iter++ {
		// Deterministic "random" ops using iteration number.
		numOps := 3 + (iter % 5) // 3-7 ops per iteration

		for op := 0; op < numOps; op++ {
			lba := uint64((iter*7 + op*13) % maxLBA)
			action := (iter + op) % 3 // 0=write, 1=overwrite, 2=trim

			switch action {
			case 0, 1: // write
				fill := byte('A' + (iter+op)%26)
				err := v.WriteLBA(lba, makeBlock(fill))
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						continue // WAL full, skip this op
					}
					t.Fatalf("iter %d op %d: WriteLBA(%d): %v", iter, op, lba, err)
				}
				oracle[lba] = fill
			case 2: // trim
				err := v.Trim(lba, blockSize)
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						continue
					}
					t.Fatalf("iter %d op %d: Trim(%d): %v", iter, op, lba, err)
				}
				// Trimmed blocks are expected to read back as zeros.
				oracle[lba] = 0
			}
		}

		// Sync WAL.
		if err := v.SyncCache(); err != nil {
			t.Fatalf("iter %d: SyncCache: %v", iter, err)
		}

		// Simulate crash: close fd without clean shutdown.
		v.fd.Sync() // ensure WAL is on disk
		// Write superblock with current WAL positions for recovery.
		v.super.WALHead = v.wal.LogicalHead()
		v.super.WALTail = v.wal.LogicalTail()
		if _, seekErr := v.fd.Seek(0, 0); seekErr != nil {
			t.Fatalf("iter %d: Seek: %v", iter, seekErr)
		}
		v.super.WriteTo(v.fd)
		v.fd.Sync()

		// Hard crash: close fd, stop goroutines.
		v.groupCommit.Stop()
		v.flusher.Stop()
		v.fd.Close()

		// Reopen with recovery.
		v, err = OpenBlockVol(path)
		if err != nil {
			t.Fatalf("iter %d: OpenBlockVol: %v", iter, err)
		}

		// Verify oracle against actual reads.
		for lba, fill := range oracle {
			got, readErr := v.ReadLBA(lba, blockSize)
			if readErr != nil {
				t.Fatalf("iter %d: ReadLBA(%d): %v", iter, lba, readErr)
			}
			var expected []byte
			if fill == 0 {
				expected = make([]byte, blockSize)
			} else {
				expected = makeBlock(fill)
			}
			if !bytes.Equal(got, expected) {
				t.Fatalf("iter %d: block %d mismatch: got[0]=%d want[0]=%d", iter, lba, got[0], expected[0])
			}
		}
	}

	v.Close()
}
|||
|
|||
func testTrimLargeLengthReadReturnsZero(t *testing.T) { |
|||
v := createTestVol(t) |
|||
defer v.Close() |
|||
|
|||
// Write data first.
|
|||
data := makeBlock('A') |
|||
if err := v.WriteLBA(0, data); err != nil { |
|||
t.Fatalf("WriteLBA: %v", err) |
|||
} |
|||
|
|||
// Verify data is written.
|
|||
got, err := v.ReadLBA(0, 4096) |
|||
if err != nil { |
|||
t.Fatalf("ReadLBA before trim: %v", err) |
|||
} |
|||
if !bytes.Equal(got, data) { |
|||
t.Fatal("data mismatch before trim") |
|||
} |
|||
|
|||
// Trim with a length larger than WAL size — should still work.
|
|||
// The trim Length is metadata (trim extent), not a data allocation.
|
|||
if err := v.Trim(0, 4096); err != nil { |
|||
t.Fatalf("Trim: %v", err) |
|||
} |
|||
|
|||
// Read should return zeros (TRIM entry in dirty map).
|
|||
got, err = v.ReadLBA(0, 4096) |
|||
if err != nil { |
|||
t.Fatalf("ReadLBA after trim: %v", err) |
|||
} |
|||
if !bytes.Equal(got, make([]byte, 4096)) { |
|||
t.Error("read after trim should return zeros") |
|||
} |
|||
} |
|||
@ -0,0 +1,110 @@ |
|||
package blockvol |
|||
|
|||
import "sync" |
|||
|
|||
// dirtyEntry tracks a single LBA's location in the WAL.
type dirtyEntry struct {
	walOffset uint64
	lsn       uint64
	length    uint32
}

// DirtyMap is a concurrency-safe map from LBA to WAL offset.
// Phase 1 uses a single shard; Phase 3 upgrades to 256 shards.
type DirtyMap struct {
	mu sync.RWMutex
	m  map[uint64]dirtyEntry
}

// NewDirtyMap creates an empty dirty map.
func NewDirtyMap() *DirtyMap {
	return &DirtyMap{m: make(map[uint64]dirtyEntry)}
}

// Put records that the given LBA has dirty data at walOffset.
// An existing entry for the same LBA is replaced.
func (d *DirtyMap) Put(lba uint64, walOffset uint64, lsn uint64, length uint32) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.m[lba] = dirtyEntry{walOffset: walOffset, lsn: lsn, length: length}
}

// Get returns the dirty entry for lba. ok is false if lba is not dirty.
func (d *DirtyMap) Get(lba uint64) (walOffset uint64, lsn uint64, length uint32, ok bool) {
	d.mu.RLock()
	defer d.mu.RUnlock()
	e, found := d.m[lba]
	if !found {
		return 0, 0, 0, false
	}
	return e.walOffset, e.lsn, e.length, true
}

// Delete removes a single LBA from the dirty map.
func (d *DirtyMap) Delete(lba uint64) {
	d.mu.Lock()
	defer d.mu.Unlock()
	delete(d.m, lba)
}

// rangeEntry is a snapshot of one dirty entry for lock-free iteration.
type rangeEntry struct {
	lba       uint64
	walOffset uint64
	lsn       uint64
	length    uint32
}

// Range calls fn for each dirty entry with LBA in [start, start+count).
// Entries are copied under lock, then fn is called without holding the lock,
// so fn may safely call back into DirtyMap (Put, Delete, etc.).
func (d *DirtyMap) Range(start uint64, count uint32, fn func(lba, walOffset, lsn uint64, length uint32)) {
	limit := start + uint64(count)

	// Copy matching entries while holding the read lock.
	d.mu.RLock()
	matched := make([]rangeEntry, 0, len(d.m))
	for lba, e := range d.m {
		if lba < start || lba >= limit {
			continue
		}
		matched = append(matched, rangeEntry{lba, e.walOffset, e.lsn, e.length})
	}
	d.mu.RUnlock()

	// Invoke the callback with the lock released.
	for _, e := range matched {
		fn(e.lba, e.walOffset, e.lsn, e.length)
	}
}

// Len returns the number of dirty entries.
func (d *DirtyMap) Len() int {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return len(d.m)
}

// SnapshotEntry is an exported snapshot of one dirty entry.
type SnapshotEntry struct {
	Lba       uint64
	WalOffset uint64
	Lsn       uint64
	Length    uint32
}

// Snapshot returns a copy of all dirty entries. The snapshot is taken under
// read lock but returned without holding the lock.
func (d *DirtyMap) Snapshot() []SnapshotEntry {
	d.mu.RLock()
	out := make([]SnapshotEntry, 0, len(d.m))
	for lba, e := range d.m {
		out = append(out, SnapshotEntry{
			Lba:       lba,
			WalOffset: e.walOffset,
			Lsn:       e.lsn,
			Length:    e.length,
		})
	}
	d.mu.RUnlock()
	return out
}
|||
@ -0,0 +1,153 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"sync" |
|||
"testing" |
|||
) |
|||
|
|||
func TestDirtyMap(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "dirty_put_get", run: testDirtyPutGet}, |
|||
{name: "dirty_overwrite", run: testDirtyOverwrite}, |
|||
{name: "dirty_delete", run: testDirtyDelete}, |
|||
{name: "dirty_range_query", run: testDirtyRangeQuery}, |
|||
{name: "dirty_empty", run: testDirtyEmpty}, |
|||
{name: "dirty_concurrent_rw", run: testDirtyConcurrentRW}, |
|||
{name: "dirty_range_modify", run: testDirtyRangeModify}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testDirtyPutGet(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
dm.Put(100, 5000, 1, 4096) |
|||
|
|||
off, lsn, length, ok := dm.Get(100) |
|||
if !ok { |
|||
t.Fatal("Get(100) returned not-found") |
|||
} |
|||
if off != 5000 { |
|||
t.Errorf("walOffset = %d, want 5000", off) |
|||
} |
|||
if lsn != 1 { |
|||
t.Errorf("lsn = %d, want 1", lsn) |
|||
} |
|||
if length != 4096 { |
|||
t.Errorf("length = %d, want 4096", length) |
|||
} |
|||
} |
|||
|
|||
func testDirtyOverwrite(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
dm.Put(100, 5000, 1, 4096) |
|||
dm.Put(100, 9000, 2, 4096) |
|||
|
|||
off, lsn, _, ok := dm.Get(100) |
|||
if !ok { |
|||
t.Fatal("Get(100) returned not-found") |
|||
} |
|||
if off != 9000 { |
|||
t.Errorf("walOffset = %d, want 9000 (latest)", off) |
|||
} |
|||
if lsn != 2 { |
|||
t.Errorf("lsn = %d, want 2 (latest)", lsn) |
|||
} |
|||
} |
|||
|
|||
func testDirtyDelete(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
dm.Put(100, 5000, 1, 4096) |
|||
dm.Delete(100) |
|||
|
|||
_, _, _, ok := dm.Get(100) |
|||
if ok { |
|||
t.Error("Get(100) should return not-found after delete") |
|||
} |
|||
} |
|||
|
|||
func testDirtyRangeQuery(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
for i := uint64(100); i < 110; i++ { |
|||
dm.Put(i, i*1000, i, 4096) |
|||
} |
|||
|
|||
found := make(map[uint64]bool) |
|||
dm.Range(100, 10, func(lba, walOffset, lsn uint64, length uint32) { |
|||
found[lba] = true |
|||
}) |
|||
|
|||
if len(found) != 10 { |
|||
t.Errorf("Range returned %d entries, want 10", len(found)) |
|||
} |
|||
for i := uint64(100); i < 110; i++ { |
|||
if !found[i] { |
|||
t.Errorf("Range missing LBA %d", i) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testDirtyEmpty(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
|
|||
_, _, _, ok := dm.Get(0) |
|||
if ok { |
|||
t.Error("empty map: Get(0) should return not-found") |
|||
} |
|||
_, _, _, ok = dm.Get(999) |
|||
if ok { |
|||
t.Error("empty map: Get(999) should return not-found") |
|||
} |
|||
if dm.Len() != 0 { |
|||
t.Errorf("empty map: Len() = %d, want 0", dm.Len()) |
|||
} |
|||
} |
|||
|
|||
func testDirtyConcurrentRW(t *testing.T) { |
|||
dm := NewDirtyMap() |
|||
const goroutines = 16 |
|||
const opsPerGoroutine = 1000 |
|||
|
|||
var wg sync.WaitGroup |
|||
wg.Add(goroutines) |
|||
for g := 0; g < goroutines; g++ { |
|||
go func(id int) { |
|||
defer wg.Done() |
|||
base := uint64(id * opsPerGoroutine) |
|||
for i := uint64(0); i < opsPerGoroutine; i++ { |
|||
lba := base + i |
|||
dm.Put(lba, lba*10, lba, 4096) |
|||
dm.Get(lba) |
|||
if i%3 == 0 { |
|||
dm.Delete(lba) |
|||
} |
|||
} |
|||
}(g) |
|||
} |
|||
wg.Wait() |
|||
|
|||
// No assertion on final state -- the test passes if no race detector fires.
|
|||
} |
|||
|
|||
func testDirtyRangeModify(t *testing.T) { |
|||
// Verify that Range callback can safely call Delete without deadlock.
|
|||
dm := NewDirtyMap() |
|||
for i := uint64(0); i < 10; i++ { |
|||
dm.Put(i, i*1000, i+1, 4096) |
|||
} |
|||
|
|||
// Delete every entry from within the Range callback.
|
|||
dm.Range(0, 10, func(lba, walOffset, lsn uint64, length uint32) { |
|||
dm.Delete(lba) |
|||
}) |
|||
|
|||
if dm.Len() != 0 { |
|||
t.Errorf("after Range+Delete: Len() = %d, want 0", dm.Len()) |
|||
} |
|||
} |
|||
@ -0,0 +1,239 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
"os" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// Flusher copies WAL entries to the extent region and frees WAL space.
// It runs as a background goroutine and can also be triggered manually.
type Flusher struct {
	fd          *os.File    // open volume file, shared with the write path
	super       *Superblock // in-memory superblock; checkpoint fields updated on flush
	wal         *WALWriter  // WAL whose tail the flusher advances
	dirtyMap    *DirtyMap   // LBA → WAL-location index drained by FlushOnce
	walOffset   uint64      // absolute file offset of WAL region
	walSize     uint64
	blockSize   uint32
	extentStart uint64 // absolute file offset of extent region

	// mu guards checkpointLSN and checkpointTail.
	mu             sync.Mutex
	checkpointLSN  uint64 // last flushed LSN
	checkpointTail uint64 // WAL physical tail after last flush

	interval time.Duration // periodic flush interval
	notifyCh chan struct{} // capacity-1 wakeup channel (sends coalesce)
	stopCh   chan struct{} // closed by Stop to end Run
	done     chan struct{} // closed when Run exits
	stopOnce sync.Once     // makes Stop idempotent
}

// FlusherConfig configures the flusher.
type FlusherConfig struct {
	FD       *os.File      // open volume file
	Super    *Superblock   // source of WAL offsets and checkpoint state
	WAL      *WALWriter    // WAL writer to advance after flushing
	DirtyMap *DirtyMap     // dirty map to drain
	Interval time.Duration // default 100ms
}
|||
|
|||
// NewFlusher creates a flusher. Call Run() in a goroutine.
|
|||
func NewFlusher(cfg FlusherConfig) *Flusher { |
|||
if cfg.Interval == 0 { |
|||
cfg.Interval = 100 * time.Millisecond |
|||
} |
|||
return &Flusher{ |
|||
fd: cfg.FD, |
|||
super: cfg.Super, |
|||
wal: cfg.WAL, |
|||
dirtyMap: cfg.DirtyMap, |
|||
walOffset: cfg.Super.WALOffset, |
|||
walSize: cfg.Super.WALSize, |
|||
blockSize: cfg.Super.BlockSize, |
|||
extentStart: cfg.Super.WALOffset + cfg.Super.WALSize, |
|||
checkpointLSN: cfg.Super.WALCheckpointLSN, |
|||
checkpointTail: 0, |
|||
interval: cfg.Interval, |
|||
notifyCh: make(chan struct{}, 1), |
|||
stopCh: make(chan struct{}), |
|||
done: make(chan struct{}), |
|||
} |
|||
} |
|||
|
|||
// Run is the flusher main loop. Call in a goroutine.
// It performs a flush cycle on every interval tick and on every Notify,
// and exits when Stop is called.
func (f *Flusher) Run() {
	defer close(f.done)
	ticker := time.NewTicker(f.interval)
	defer ticker.Stop()

	for {
		select {
		case <-f.stopCh:
			return
		case <-ticker.C:
			// NOTE(review): FlushOnce errors are dropped here — background
			// flushes appear to be best-effort and retried next cycle.
			// Confirm that durability-sensitive callers use SyncCache instead.
			f.FlushOnce()
		case <-f.notifyCh:
			f.FlushOnce()
		}
	}
}

// Notify wakes up the flusher for an immediate flush cycle.
// The send is non-blocking: if a wakeup is already queued, this one
// coalesces with it.
func (f *Flusher) Notify() {
	select {
	case f.notifyCh <- struct{}{}:
	default:
	}
}

// Stop shuts down the flusher and waits for Run to exit.
// Safe to call multiple times.
func (f *Flusher) Stop() {
	f.stopOnce.Do(func() {
		close(f.stopCh)
	})
	<-f.done
}
|||
|
|||
// FlushOnce performs a single flush cycle: scan dirty map, copy data to
|
|||
// extent region, fsync, update checkpoint, advance WAL tail.
|
|||
func (f *Flusher) FlushOnce() error { |
|||
// Snapshot dirty entries. We use a full scan (Range over all possible LBAs
|
|||
// is impractical), so we collect from the dirty map directly.
|
|||
type flushEntry struct { |
|||
lba uint64 |
|||
walOff uint64 |
|||
lsn uint64 |
|||
length uint32 |
|||
} |
|||
|
|||
entries := f.dirtyMap.Snapshot() |
|||
if len(entries) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
// Find the max LSN and max WAL offset to know where to advance tail.
|
|||
var maxLSN uint64 |
|||
var maxWALEnd uint64 |
|||
|
|||
for _, e := range entries { |
|||
// Read the WAL entry and copy data to extent region.
|
|||
headerBuf := make([]byte, walEntryHeaderSize) |
|||
absWALOff := int64(f.walOffset + e.WalOffset) |
|||
if _, err := f.fd.ReadAt(headerBuf, absWALOff); err != nil { |
|||
return fmt.Errorf("flusher: read WAL header at %d: %w", absWALOff, err) |
|||
} |
|||
|
|||
// Parse entry type and length.
|
|||
entryType := headerBuf[16] // Type at LSN(8)+Epoch(8)=16
|
|||
dataLen := parseLength(headerBuf) |
|||
|
|||
if entryType == EntryTypeWrite && dataLen > 0 { |
|||
// Read full entry.
|
|||
entryLen := walEntryHeaderSize + int(dataLen) |
|||
fullBuf := make([]byte, entryLen) |
|||
if _, err := f.fd.ReadAt(fullBuf, absWALOff); err != nil { |
|||
return fmt.Errorf("flusher: read WAL entry at %d: %w", absWALOff, err) |
|||
} |
|||
|
|||
entry, err := DecodeWALEntry(fullBuf) |
|||
if err != nil { |
|||
return fmt.Errorf("flusher: decode WAL entry: %w", err) |
|||
} |
|||
|
|||
// Write data to extent region.
|
|||
blocks := entry.Length / f.blockSize |
|||
for i := uint32(0); i < blocks; i++ { |
|||
blockLBA := entry.LBA + uint64(i) |
|||
extentOff := int64(f.extentStart + blockLBA*uint64(f.blockSize)) |
|||
blockData := entry.Data[i*f.blockSize : (i+1)*f.blockSize] |
|||
if _, err := f.fd.WriteAt(blockData, extentOff); err != nil { |
|||
return fmt.Errorf("flusher: write extent at LBA %d: %w", blockLBA, err) |
|||
} |
|||
} |
|||
|
|||
// Track WAL end position for tail advance.
|
|||
walEnd := e.WalOffset + uint64(entryLen) |
|||
if walEnd > maxWALEnd { |
|||
maxWALEnd = walEnd |
|||
} |
|||
} else if entryType == EntryTypeTrim { |
|||
// TRIM entries: zero the extent region for this LBA.
|
|||
// Each dirty map entry represents one trimmed block.
|
|||
zeroBlock := make([]byte, f.blockSize) |
|||
extentOff := int64(f.extentStart + e.Lba*uint64(f.blockSize)) |
|||
if _, err := f.fd.WriteAt(zeroBlock, extentOff); err != nil { |
|||
return fmt.Errorf("flusher: zero extent at LBA %d: %w", e.Lba, err) |
|||
} |
|||
|
|||
// TRIM entry has no data payload, just a header.
|
|||
walEnd := e.WalOffset + uint64(walEntryHeaderSize) |
|||
if walEnd > maxWALEnd { |
|||
maxWALEnd = walEnd |
|||
} |
|||
} |
|||
|
|||
if e.Lsn > maxLSN { |
|||
maxLSN = e.Lsn |
|||
} |
|||
} |
|||
|
|||
// Fsync extent writes.
|
|||
if err := f.fd.Sync(); err != nil { |
|||
return fmt.Errorf("flusher: fsync extent: %w", err) |
|||
} |
|||
|
|||
// Remove flushed entries from dirty map.
|
|||
f.mu.Lock() |
|||
for _, e := range entries { |
|||
// Only remove if the dirty map entry still has the same LSN
|
|||
// (a newer write may have updated it).
|
|||
_, currentLSN, _, ok := f.dirtyMap.Get(e.Lba) |
|||
if ok && currentLSN == e.Lsn { |
|||
f.dirtyMap.Delete(e.Lba) |
|||
} |
|||
} |
|||
f.checkpointLSN = maxLSN |
|||
f.checkpointTail = maxWALEnd |
|||
f.mu.Unlock() |
|||
|
|||
// Advance WAL tail to free space.
|
|||
if maxWALEnd > 0 { |
|||
f.wal.AdvanceTail(maxWALEnd) |
|||
} |
|||
|
|||
// Update superblock checkpoint.
|
|||
f.updateSuperblockCheckpoint(maxLSN, f.wal.Tail()) |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// updateSuperblockCheckpoint persists the new checkpoint LSN together with
// the WAL's current logical head/tail into the on-disk superblock (at file
// offset 0), then fsyncs so the checkpoint survives a crash.
//
// NOTE(review): the walTail parameter is never read — the tail written comes
// from f.wal.LogicalTail() instead. Confirm whether the parameter is
// vestigial or whether the caller's value was meant to be used.
func (f *Flusher) updateSuperblockCheckpoint(checkpointLSN uint64, walTail uint64) error {
	f.super.WALCheckpointLSN = checkpointLSN
	f.super.WALHead = f.wal.LogicalHead()
	f.super.WALTail = f.wal.LogicalTail()

	if _, err := f.fd.Seek(0, 0); err != nil {
		return fmt.Errorf("flusher: seek to superblock: %w", err)
	}
	if _, err := f.super.WriteTo(f.fd); err != nil {
		return fmt.Errorf("flusher: write superblock: %w", err)
	}
	return f.fd.Sync()
}
|||
|
|||
// CheckpointLSN returns the last flushed LSN.
|
|||
func (f *Flusher) CheckpointLSN() uint64 { |
|||
f.mu.Lock() |
|||
defer f.mu.Unlock() |
|||
return f.checkpointLSN |
|||
} |
|||
|
|||
// parseLength extracts the Length field from a WAL entry header buffer.
|
|||
func parseLength(headerBuf []byte) uint32 { |
|||
// Length at LSN(8)+Epoch(8)+Type(1)+Flags(1)+LBA(8) = 26
|
|||
return binary.LittleEndian.Uint32(headerBuf[26:]) |
|||
} |
|||
@ -0,0 +1,261 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"path/filepath" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestFlusher(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "flush_moves_data", run: testFlushMovesData}, |
|||
{name: "flush_idempotent", run: testFlushIdempotent}, |
|||
{name: "flush_concurrent_writes", run: testFlushConcurrentWrites}, |
|||
{name: "flush_frees_wal_space", run: testFlushFreesWALSpace}, |
|||
{name: "flush_partial", run: testFlushPartial}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func createTestVolWithFlusher(t *testing.T) (*BlockVol, *Flusher) { |
|||
t.Helper() |
|||
dir := t.TempDir() |
|||
path := filepath.Join(dir, "test.blockvol") |
|||
v, err := CreateBlockVol(path, CreateOptions{ |
|||
VolumeSize: 1 * 1024 * 1024, // 1MB
|
|||
BlockSize: 4096, |
|||
WALSize: 256 * 1024, // 256KB WAL
|
|||
}) |
|||
if err != nil { |
|||
t.Fatalf("CreateBlockVol: %v", err) |
|||
} |
|||
|
|||
f := NewFlusher(FlusherConfig{ |
|||
FD: v.fd, |
|||
Super: &v.super, |
|||
WAL: v.wal, |
|||
DirtyMap: v.dirtyMap, |
|||
Interval: 1 * time.Hour, // don't auto-flush in tests
|
|||
}) |
|||
|
|||
return v, f |
|||
} |
|||
|
|||
// testFlushMovesData writes 10 blocks, flushes once, and verifies that the
// dirty map empties, the checkpoint advances, and reads (now served from the
// extent region) still return the written data.
func testFlushMovesData(t *testing.T) {
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// Write 10 blocks.
	for i := uint64(0); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if v.dirtyMap.Len() != 10 {
		t.Fatalf("dirty map len = %d, want 10", v.dirtyMap.Len())
	}

	// Run flusher.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// Dirty map should be empty.
	if v.dirtyMap.Len() != 0 {
		t.Errorf("after flush: dirty map len = %d, want 0", v.dirtyMap.Len())
	}

	// Checkpoint should have advanced.
	if f.CheckpointLSN() == 0 {
		t.Error("checkpoint LSN should be > 0 after flush")
	}

	// Read from extent (dirty map is empty, so reads go to extent).
	for i := uint64(0); i < 10; i++ {
		got, err := v.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after flush: %v", i, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+i))) {
			t.Errorf("block %d: data mismatch after flush", i)
		}
	}
}
|||
|
|||
func testFlushIdempotent(t *testing.T) { |
|||
v, f := createTestVolWithFlusher(t) |
|||
defer v.Close() |
|||
|
|||
data := makeBlock('X') |
|||
if err := v.WriteLBA(0, data); err != nil { |
|||
t.Fatalf("WriteLBA: %v", err) |
|||
} |
|||
|
|||
// Flush twice.
|
|||
if err := f.FlushOnce(); err != nil { |
|||
t.Fatalf("FlushOnce 1: %v", err) |
|||
} |
|||
if err := f.FlushOnce(); err != nil { |
|||
t.Fatalf("FlushOnce 2: %v", err) |
|||
} |
|||
|
|||
// Data should still be correct.
|
|||
got, err := v.ReadLBA(0, 4096) |
|||
if err != nil { |
|||
t.Fatalf("ReadLBA after double flush: %v", err) |
|||
} |
|||
if !bytes.Equal(got, data) { |
|||
t.Error("data mismatch after double flush") |
|||
} |
|||
} |
|||
|
|||
// testFlushConcurrentWrites interleaves a flush between two write batches and
// checks that pre-flush blocks read from the extent region while post-flush
// blocks read from the WAL, plus that an overwrite after the flush wins.
//
// NOTE(review): despite the name, all operations here run sequentially on one
// goroutine; what is exercised is WAL/extent coexistence, not concurrency.
func testFlushConcurrentWrites(t *testing.T) {
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// Write blocks 0-4.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Flush (moves blocks 0-4 to extent).
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// Write blocks 5-9 AFTER flush.
	for i := uint64(5); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Blocks 0-4 should read from extent, blocks 5-9 from WAL.
	for i := uint64(0); i < 10; i++ {
		got, err := v.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+i))) {
			t.Errorf("block %d: data mismatch", i)
		}
	}

	// Dirty map should have 5 entries (blocks 5-9).
	if v.dirtyMap.Len() != 5 {
		t.Errorf("dirty map len = %d, want 5", v.dirtyMap.Len())
	}

	// Also: overwrite block 0 after flush -- new write should go to WAL.
	newData := makeBlock('Z')
	if err := v.WriteLBA(0, newData); err != nil {
		t.Fatalf("WriteLBA(0) overwrite: %v", err)
	}
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0) after overwrite: %v", err)
	}
	if !bytes.Equal(got, newData) {
		t.Error("block 0: should return overwritten data 'Z'")
	}
}
|||
|
|||
// testFlushFreesWALSpace fills most of the WAL, flushes, and verifies that
// the freed tail space lets new writes succeed again.
func testFlushFreesWALSpace(t *testing.T) {
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// Write enough blocks to fill a significant portion of WAL.
	entrySize := uint64(walEntryHeaderSize + 4096)
	walCapacity := v.super.WALSize / entrySize
	// Write ~80% of capacity.
	writeCount := int(walCapacity * 80 / 100)

	for i := 0; i < writeCount; i++ {
		if err := v.WriteLBA(uint64(i), makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Try to write more -- should eventually fail with WAL full.
	var walFullBefore bool
	for i := writeCount; i < writeCount+int(walCapacity); i++ {
		if err := v.WriteLBA(uint64(i%writeCount), makeBlock('X')); err != nil {
			walFullBefore = true
			break
		}
	}

	// Flush to free WAL space.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// WAL tail should have advanced (free space available).
	// New writes should succeed.
	if err := v.WriteLBA(0, makeBlock('Y')); err != nil {
		t.Fatalf("WriteLBA after flush: %v", err)
	}

	// Log whether WAL was full before flush.
	if walFullBefore {
		t.Log("WAL was full before flush, writes succeeded after flush")
	}
}
|||
|
|||
// testFlushPartial runs two write+flush rounds and checks that the checkpoint
// LSN strictly advances between rounds and that all ten blocks end up
// readable from the extent region.
func testFlushPartial(t *testing.T) {
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// Write blocks 0-4.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Flush once (all 5 blocks).
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	checkpointAfterFirst := f.CheckpointLSN()

	// Write blocks 5-9.
	for i := uint64(5); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Simulate partial flush: flusher runs again, should handle new entries.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce 2: %v", err)
	}

	checkpointAfterSecond := f.CheckpointLSN()
	if checkpointAfterSecond <= checkpointAfterFirst {
		t.Errorf("checkpoint should advance: first=%d, second=%d", checkpointAfterFirst, checkpointAfterSecond)
	}

	// All blocks should be readable from extent.
	for i := uint64(0); i < 10; i++ {
		got, err := v.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after two flushes: %v", i, err)
		}
		if !bytes.Equal(got, makeBlock(byte('A'+i))) {
			t.Errorf("block %d: data mismatch after two flushes", i)
		}
	}
}
|||
@ -0,0 +1,167 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"errors" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
) |
|||
|
|||
var ErrGroupCommitShutdown = errors.New("blockvol: group committer shut down")

// GroupCommitter batches SyncCache requests and performs a single fsync
// for the entire batch, amortizing the cost of fsync across many callers.
type GroupCommitter struct {
	syncFunc   func() error  // called to fsync (injectable for testing)
	maxDelay   time.Duration // max wait before flushing a partial batch
	maxBatch   int           // flush immediately when this many waiters accumulate
	onDegraded func()        // called when fsync fails

	mu       sync.Mutex
	pending  []chan error
	stopped  bool // set under mu by Run() before draining
	notifyCh chan struct{}
	stopCh   chan struct{}
	done     chan struct{}
	stopOnce sync.Once

	syncCount atomic.Uint64 // number of fsyncs performed (for testing)
}

// GroupCommitterConfig configures the group committer.
type GroupCommitterConfig struct {
	SyncFunc   func() error  // required: the fsync function
	MaxDelay   time.Duration // default 1ms
	MaxBatch   int           // default 64
	OnDegraded func()        // optional: called on fsync error
}

// NewGroupCommitter creates a new group committer. Call Run() to start it.
func NewGroupCommitter(cfg GroupCommitterConfig) *GroupCommitter {
	delay := cfg.MaxDelay
	if delay == 0 {
		delay = 1 * time.Millisecond
	}
	batchLimit := cfg.MaxBatch
	if batchLimit == 0 {
		batchLimit = 64
	}
	degraded := cfg.OnDegraded
	if degraded == nil {
		degraded = func() {}
	}
	return &GroupCommitter{
		syncFunc:   cfg.SyncFunc,
		maxDelay:   delay,
		maxBatch:   batchLimit,
		onDegraded: degraded,
		notifyCh:   make(chan struct{}, 1),
		stopCh:     make(chan struct{}),
		done:       make(chan struct{}),
	}
}

// Run is the main loop. Call this in a goroutine.
func (gc *GroupCommitter) Run() {
	defer close(gc.done)
	for {
		// Block until the first waiter arrives or shutdown is requested.
		select {
		case <-gc.stopCh:
			gc.markStoppedAndDrain()
			return
		case <-gc.notifyCh:
		}

		// Collect a batch: wait up to maxDelay for more waiters, flushing
		// early once maxBatch waiters have accumulated.
		timer := time.NewTimer(gc.maxDelay)
		collecting := true
		for collecting {
			gc.mu.Lock()
			queued := len(gc.pending)
			gc.mu.Unlock()
			if queued >= gc.maxBatch {
				timer.Stop()
				break
			}
			select {
			case <-gc.stopCh:
				timer.Stop()
				gc.markStoppedAndDrain()
				return
			case <-timer.C:
				collecting = false
			case <-gc.notifyCh:
				// Another waiter arrived; loop and re-check the batch size.
			}
		}

		// Claim every pending waiter as this batch.
		gc.mu.Lock()
		batch := gc.pending
		gc.pending = nil
		gc.mu.Unlock()
		if len(batch) == 0 {
			continue
		}

		// One fsync covers the whole batch.
		syncErr := gc.syncFunc()
		gc.syncCount.Add(1)
		if syncErr != nil {
			gc.onDegraded()
		}

		// Hand the shared result to every waiter.
		for _, waiter := range batch {
			waiter <- syncErr
		}
	}
}

// Submit submits a sync request and blocks until the batch fsync completes.
// Returns nil on success, the fsync error, or ErrGroupCommitShutdown if stopped.
func (gc *GroupCommitter) Submit() error {
	result := make(chan error, 1)

	gc.mu.Lock()
	if gc.stopped {
		gc.mu.Unlock()
		return ErrGroupCommitShutdown
	}
	gc.pending = append(gc.pending, result)
	gc.mu.Unlock()

	// Non-blocking notify to wake Run(); a queued wakeup coalesces.
	select {
	case gc.notifyCh <- struct{}{}:
	default:
	}

	return <-result
}

// Stop shuts down the group committer. Pending waiters receive
// ErrGroupCommitShutdown. Safe to call multiple times.
func (gc *GroupCommitter) Stop() {
	gc.stopOnce.Do(func() {
		close(gc.stopCh)
	})
	<-gc.done
}

// SyncCount returns the number of fsyncs performed (for testing).
func (gc *GroupCommitter) SyncCount() uint64 {
	return gc.syncCount.Load()
}

// markStoppedAndDrain sets the stopped flag under mu and drains all pending
// waiters with ErrGroupCommitShutdown. Marking stopped before draining means
// no later Submit() can enqueue and be lost, closing the race window from
// QA-002.
func (gc *GroupCommitter) markStoppedAndDrain() {
	gc.mu.Lock()
	gc.stopped = true
	drained := gc.pending
	gc.pending = nil
	gc.mu.Unlock()
	for _, waiter := range drained {
		waiter <- ErrGroupCommitShutdown
	}
}
|||
@ -0,0 +1,287 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"errors" |
|||
"sync" |
|||
"sync/atomic" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestGroupCommitter(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "group_single_barrier", run: testGroupSingleBarrier}, |
|||
{name: "group_batch_10", run: testGroupBatch10}, |
|||
{name: "group_max_delay", run: testGroupMaxDelay}, |
|||
{name: "group_max_batch", run: testGroupMaxBatch}, |
|||
{name: "group_fsync_error", run: testGroupFsyncError}, |
|||
{name: "group_sequential", run: testGroupSequential}, |
|||
{name: "group_shutdown", run: testGroupShutdown}, |
|||
{name: "group_submit_after_stop", run: testGroupSubmitAfterStop}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testGroupSingleBarrier(t *testing.T) { |
|||
var syncCalls atomic.Uint64 |
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { |
|||
syncCalls.Add(1) |
|||
return nil |
|||
}, |
|||
MaxDelay: 10 * time.Millisecond, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
if err := gc.Submit(); err != nil { |
|||
t.Fatalf("Submit: %v", err) |
|||
} |
|||
|
|||
if c := syncCalls.Load(); c != 1 { |
|||
t.Errorf("syncCalls = %d, want 1", c) |
|||
} |
|||
} |
|||
|
|||
func testGroupBatch10(t *testing.T) { |
|||
var syncCalls atomic.Uint64 |
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { |
|||
syncCalls.Add(1) |
|||
return nil |
|||
}, |
|||
MaxDelay: 50 * time.Millisecond, |
|||
MaxBatch: 64, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
const n = 10 |
|||
var wg sync.WaitGroup |
|||
errs := make([]error, n) |
|||
|
|||
// Launch all 10 concurrently.
|
|||
wg.Add(n) |
|||
for i := 0; i < n; i++ { |
|||
go func(idx int) { |
|||
defer wg.Done() |
|||
errs[idx] = gc.Submit() |
|||
}(i) |
|||
} |
|||
wg.Wait() |
|||
|
|||
for i, err := range errs { |
|||
if err != nil { |
|||
t.Errorf("Submit[%d]: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
// All 10 should be batched into 1 fsync (maybe 2 if timing is unlucky).
|
|||
if c := syncCalls.Load(); c > 2 { |
|||
t.Errorf("syncCalls = %d, want 1-2 (batched)", c) |
|||
} |
|||
} |
|||
|
|||
func testGroupMaxDelay(t *testing.T) { |
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { return nil }, |
|||
MaxDelay: 5 * time.Millisecond, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
start := time.Now() |
|||
if err := gc.Submit(); err != nil { |
|||
t.Fatalf("Submit: %v", err) |
|||
} |
|||
elapsed := time.Since(start) |
|||
|
|||
// Should complete within ~maxDelay + some margin, not take much longer.
|
|||
if elapsed > 50*time.Millisecond { |
|||
t.Errorf("Submit took %v, expected ~5ms", elapsed) |
|||
} |
|||
} |
|||
|
|||
func testGroupMaxBatch(t *testing.T) { |
|||
var syncCalls atomic.Uint64 |
|||
const batch = 64 |
|||
|
|||
// Use a slow sync to ensure we can detect immediate trigger.
|
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { |
|||
syncCalls.Add(1) |
|||
return nil |
|||
}, |
|||
MaxDelay: 5 * time.Second, // very long delay -- should NOT wait this long
|
|||
MaxBatch: batch, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
var wg sync.WaitGroup |
|||
wg.Add(batch) |
|||
for i := 0; i < batch; i++ { |
|||
go func() { |
|||
defer wg.Done() |
|||
gc.Submit() |
|||
}() |
|||
} |
|||
|
|||
// Wait with timeout -- if maxBatch triggers, should complete fast.
|
|||
done := make(chan struct{}) |
|||
go func() { |
|||
wg.Wait() |
|||
close(done) |
|||
}() |
|||
|
|||
select { |
|||
case <-done: |
|||
// Good -- completed without waiting 5 seconds.
|
|||
case <-time.After(2 * time.Second): |
|||
t.Fatal("maxBatch=64 did not trigger immediate flush") |
|||
} |
|||
} |
|||
|
|||
func testGroupFsyncError(t *testing.T) { |
|||
errIO := errors.New("simulated EIO") |
|||
var degraded atomic.Bool |
|||
|
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { |
|||
return errIO |
|||
}, |
|||
MaxDelay: 10 * time.Millisecond, |
|||
OnDegraded: func() { degraded.Store(true) }, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
const n = 5 |
|||
var wg sync.WaitGroup |
|||
errs := make([]error, n) |
|||
|
|||
wg.Add(n) |
|||
for i := 0; i < n; i++ { |
|||
go func(idx int) { |
|||
defer wg.Done() |
|||
errs[idx] = gc.Submit() |
|||
}(i) |
|||
} |
|||
wg.Wait() |
|||
|
|||
for i, err := range errs { |
|||
if !errors.Is(err, errIO) { |
|||
t.Errorf("Submit[%d]: got %v, want %v", i, err, errIO) |
|||
} |
|||
} |
|||
|
|||
if !degraded.Load() { |
|||
t.Error("onDegraded was not called") |
|||
} |
|||
} |
|||
|
|||
func testGroupSequential(t *testing.T) { |
|||
var syncCalls atomic.Uint64 |
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { |
|||
syncCalls.Add(1) |
|||
return nil |
|||
}, |
|||
MaxDelay: 10 * time.Millisecond, |
|||
}) |
|||
go gc.Run() |
|||
defer gc.Stop() |
|||
|
|||
// First sync.
|
|||
if err := gc.Submit(); err != nil { |
|||
t.Fatalf("Submit 1: %v", err) |
|||
} |
|||
|
|||
// Second sync (after first completes).
|
|||
if err := gc.Submit(); err != nil { |
|||
t.Fatalf("Submit 2: %v", err) |
|||
} |
|||
|
|||
if c := syncCalls.Load(); c != 2 { |
|||
t.Errorf("syncCalls = %d, want 2 (two separate fsyncs)", c) |
|||
} |
|||
} |
|||
|
|||
// testGroupShutdown verifies that Stop lets an in-flight fsync complete (its
// waiter gets nil) while waiters that enqueued during the fsync are drained
// with ErrGroupCommitShutdown.
func testGroupShutdown(t *testing.T) {
	// Block fsync so we can shut down while waiters are pending.
	syncStarted := make(chan struct{})
	syncBlock := make(chan struct{})
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			close(syncStarted)
			<-syncBlock // block until test releases
			return nil
		},
		MaxDelay: 1 * time.Millisecond,
	})
	go gc.Run()

	// Submit one request that will block on fsync.
	errCh1 := make(chan error, 1)
	go func() {
		errCh1 <- gc.Submit()
	}()

	// Wait for fsync to start.
	<-syncStarted

	// Submit another request while fsync is in progress.
	errCh2 := make(chan error, 1)
	go func() {
		errCh2 <- gc.Submit()
	}()
	time.Sleep(5 * time.Millisecond) // let it enqueue

	// Stop the group committer (will drain pending with ErrGroupCommitShutdown).
	// The blocked fsync must be released from a separate goroutine because
	// Stop blocks until Run exits — releasing it inline would deadlock.
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(syncBlock) // unblock the fsync first
	}()
	gc.Stop()

	// First waiter should get nil (fsync succeeded).
	if err := <-errCh1; err != nil {
		t.Errorf("first waiter: %v, want nil", err)
	}

	// Second waiter should get shutdown error.
	if err := <-errCh2; !errors.Is(err, ErrGroupCommitShutdown) {
		t.Errorf("second waiter: %v, want ErrGroupCommitShutdown", err)
	}
}
|||
|
|||
func testGroupSubmitAfterStop(t *testing.T) { |
|||
gc := NewGroupCommitter(GroupCommitterConfig{ |
|||
SyncFunc: func() error { return nil }, |
|||
MaxDelay: 10 * time.Millisecond, |
|||
}) |
|||
go gc.Run() |
|||
gc.Stop() |
|||
|
|||
// Submit after Stop must return ErrGroupCommitShutdown, not deadlock.
|
|||
done := make(chan error, 1) |
|||
go func() { |
|||
done <- gc.Submit() |
|||
}() |
|||
|
|||
select { |
|||
case err := <-done: |
|||
if !errors.Is(err, ErrGroupCommitShutdown) { |
|||
t.Errorf("Submit after Stop: %v, want ErrGroupCommitShutdown", err) |
|||
} |
|||
case <-time.After(2 * time.Second): |
|||
t.Fatal("Submit after Stop deadlocked") |
|||
} |
|||
} |
|||
@ -0,0 +1,141 @@ |
|||
// iscsi-target is a standalone iSCSI target backed by a BlockVol file.
|
|||
// Usage:
|
|||
//
|
|||
// iscsi-target -vol /path/to/volume.blk -addr :3260 -iqn iqn.2024.com.seaweedfs:vol1
|
|||
// iscsi-target -create -size 1G -vol /path/to/volume.blk -addr :3260
|
|||
package main |
|||
|
|||
import ( |
|||
"flag" |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
"os/signal" |
|||
"strconv" |
|||
"strings" |
|||
"syscall" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol" |
|||
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi" |
|||
) |
|||
|
|||
func main() { |
|||
volPath := flag.String("vol", "", "path to BlockVol file") |
|||
addr := flag.String("addr", ":3260", "listen address") |
|||
iqn := flag.String("iqn", "iqn.2024.com.seaweedfs:vol1", "target IQN") |
|||
create := flag.Bool("create", false, "create a new volume file") |
|||
size := flag.String("size", "1G", "volume size (e.g., 1G, 100M) — used with -create") |
|||
flag.Parse() |
|||
|
|||
if *volPath == "" { |
|||
fmt.Fprintln(os.Stderr, "error: -vol is required") |
|||
flag.Usage() |
|||
os.Exit(1) |
|||
} |
|||
|
|||
logger := log.New(os.Stdout, "[iscsi] ", log.LstdFlags) |
|||
|
|||
var vol *blockvol.BlockVol |
|||
var err error |
|||
|
|||
if *create { |
|||
volSize, parseErr := parseSize(*size) |
|||
if parseErr != nil { |
|||
log.Fatalf("invalid size %q: %v", *size, parseErr) |
|||
} |
|||
vol, err = blockvol.CreateBlockVol(*volPath, blockvol.CreateOptions{ |
|||
VolumeSize: volSize, |
|||
BlockSize: 4096, |
|||
WALSize: 64 * 1024 * 1024, |
|||
}) |
|||
if err != nil { |
|||
log.Fatalf("create volume: %v", err) |
|||
} |
|||
logger.Printf("created volume: %s (%s)", *volPath, *size) |
|||
} else { |
|||
vol, err = blockvol.OpenBlockVol(*volPath) |
|||
if err != nil { |
|||
log.Fatalf("open volume: %v", err) |
|||
} |
|||
logger.Printf("opened volume: %s", *volPath) |
|||
} |
|||
defer vol.Close() |
|||
|
|||
info := vol.Info() |
|||
logger.Printf("volume: %d bytes, block=%d, healthy=%v", |
|||
info.VolumeSize, info.BlockSize, info.Healthy) |
|||
|
|||
// Create adapter
|
|||
adapter := &blockVolAdapter{vol: vol} |
|||
|
|||
// Create target server
|
|||
config := iscsi.DefaultTargetConfig() |
|||
config.TargetName = *iqn |
|||
config.TargetAlias = "SeaweedFS BlockVol" |
|||
ts := iscsi.NewTargetServer(*addr, config, logger) |
|||
ts.AddVolume(*iqn, adapter) |
|||
|
|||
// Graceful shutdown on signal
|
|||
sigCh := make(chan os.Signal, 1) |
|||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) |
|||
go func() { |
|||
sig := <-sigCh |
|||
logger.Printf("received %v, shutting down...", sig) |
|||
ts.Close() |
|||
}() |
|||
|
|||
logger.Printf("starting iSCSI target: %s on %s", *iqn, *addr) |
|||
if err := ts.ListenAndServe(); err != nil { |
|||
log.Fatalf("target server: %v", err) |
|||
} |
|||
logger.Println("target stopped") |
|||
} |
|||
|
|||
// blockVolAdapter wraps BlockVol to implement iscsi.BlockDevice.
// Every method is a thin delegation; no state is kept beyond the volume handle.
type blockVolAdapter struct {
	vol *blockvol.BlockVol // backing extent-mapped volume
}

// ReadAt delegates a read of length bytes at the given LBA to the volume.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return a.vol.ReadLBA(lba, length)
}

// WriteAt delegates a write of data at the given LBA to the volume.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
	return a.vol.WriteLBA(lba, data)
}

// Trim forwards a trim request for the given range to the volume.
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
	return a.vol.Trim(lba, length)
}

// SyncCache forwards a cache-flush request to the volume.
func (a *blockVolAdapter) SyncCache() error { return a.vol.SyncCache() }

// BlockSize reports the volume's logical block size in bytes.
func (a *blockVolAdapter) BlockSize() uint32 { return a.vol.Info().BlockSize }

// VolumeSize reports the total volume capacity in bytes.
func (a *blockVolAdapter) VolumeSize() uint64 { return a.vol.Info().VolumeSize }

// IsHealthy reports the volume's health flag from Info().
func (a *blockVolAdapter) IsHealthy() bool { return a.vol.Info().Healthy }
|||
|
|||
// parseSize parses a human-readable size string such as "512", "100M", or
// "1g" into a byte count. Recognized suffixes (case-insensitive) are K, M,
// G, and T, all powers of 1024; a bare number is taken as bytes.
// Returns an error for empty input, non-numeric input, or values that would
// overflow uint64.
func parseSize(s string) (uint64, error) {
	s = strings.TrimSpace(s)
	if len(s) == 0 {
		return 0, fmt.Errorf("empty size")
	}

	multiplier := uint64(1)
	suffix := s[len(s)-1]
	switch suffix {
	case 'K', 'k':
		multiplier = 1 << 10
		s = s[:len(s)-1]
	case 'M', 'm':
		multiplier = 1 << 20
		s = s[:len(s)-1]
	case 'G', 'g':
		multiplier = 1 << 30
		s = s[:len(s)-1]
	case 'T', 't':
		multiplier = 1 << 40
		s = s[:len(s)-1]
	}

	n, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return 0, err
	}
	// Fix: guard against silent uint64 wraparound — the original returned a
	// wrapped (tiny) size for inputs like "18446744073709551615K".
	if n > ^uint64(0)/multiplier {
		return 0, fmt.Errorf("size %s%c overflows uint64", s, suffix)
	}
	return n * multiplier, nil
}
|||
@ -0,0 +1,216 @@ |
|||
#!/usr/bin/env bash
# smoke-test.sh — iscsiadm smoke test for SeaweedFS iSCSI target
#
# Prerequisites:
#   - Linux host with iscsiadm (open-iscsi) installed
#   - Root access (for iscsiadm, mkfs, mount)
#   - iscsi-target binary built: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/
#
# Usage:
#   sudo ./smoke-test.sh [target-addr] [port]
#
# Default: starts a local target on :3260, runs discovery + login + I/O, cleans up.

# Fail fast: abort on command errors, unset variables, and pipeline failures.
set -euo pipefail

# Target endpoint and identity; addr/port may be overridden positionally.
TARGET_ADDR="${1:-127.0.0.1}"
TARGET_PORT="${2:-3260}"
TARGET_IQN="iqn.2024.com.seaweedfs:smoke"
VOL_FILE="/tmp/iscsi-smoke-test.blk"
MOUNT_POINT="/tmp/iscsi-smoke-mnt"
# PID of the background target process; empty until started.
TARGET_PID=""

# ANSI colors for log output.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

# Logging helpers; fail() exits non-zero (the EXIT trap still runs cleanup).
log() { echo -e "${GREEN}[SMOKE]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
|||
|
|||
# cleanup tears down everything the test created, in reverse order of setup:
# unmount -> iSCSI logout/delete -> kill target -> remove volume file.
# Every step is best-effort (|| true) so a partially completed run still unwinds.
cleanup() {
    log "Cleaning up..."

    # Unmount if mounted
    if mountpoint -q "$MOUNT_POINT" 2>/dev/null; then
        umount "$MOUNT_POINT" || true
    fi
    rmdir "$MOUNT_POINT" 2>/dev/null || true

    # Logout from iSCSI
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout 2>/dev/null || true
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" -o delete 2>/dev/null || true

    # Stop target
    if [[ -n "$TARGET_PID" ]] && kill -0 "$TARGET_PID" 2>/dev/null; then
        kill "$TARGET_PID"
        wait "$TARGET_PID" 2>/dev/null || true
    fi

    # Remove volume file
    rm -f "$VOL_FILE"

    log "Cleanup complete"
}

# Run cleanup on every exit path, pass or fail.
trap cleanup EXIT
|||
|
|||
# -------------------------------------------------------
# Preflight
# -------------------------------------------------------
# iscsiadm, mkfs, and mount all require root.
if [[ $EUID -ne 0 ]]; then
    fail "This script must be run as root (for iscsiadm/mount)"
fi

if ! command -v iscsiadm &>/dev/null; then
    fail "iscsiadm not found. Install open-iscsi: apt install open-iscsi"
fi

# Locate the iscsi-target binary: first next to this script, then at the repo root.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TARGET_BIN="${SCRIPT_DIR}/iscsi-target"
if [[ ! -x "$TARGET_BIN" ]]; then
    # Try the repo root
    TARGET_BIN="$(cd "$SCRIPT_DIR/../../../../../.." && pwd)/iscsi-target"
fi
if [[ ! -x "$TARGET_BIN" ]]; then
    fail "iscsi-target binary not found. Build with: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/"
fi

# -------------------------------------------------------
# Step 1: Start iSCSI target
# -------------------------------------------------------
log "Starting iSCSI target..."
# -create builds a fresh 100M volume file; cleanup removes it on exit.
"$TARGET_BIN" \
    -create \
    -vol "$VOL_FILE" \
    -size 100M \
    -addr "${TARGET_ADDR}:${TARGET_PORT}" \
    -iqn "$TARGET_IQN" &
TARGET_PID=$!
# Give the listener a moment to come up before probing it.
sleep 1

if ! kill -0 "$TARGET_PID" 2>/dev/null; then
    fail "Target failed to start"
fi
log "Target started (PID $TARGET_PID)"

# -------------------------------------------------------
# Step 2: Discovery
# -------------------------------------------------------
log "Running iscsiadm discovery..."
DISCOVERY=$(iscsiadm -m discovery -t sendtargets -p "${TARGET_ADDR}:${TARGET_PORT}" 2>&1) || {
    fail "Discovery failed: $DISCOVERY"
}
echo "$DISCOVERY"

# The SendTargets response must advertise our IQN.
if ! echo "$DISCOVERY" | grep -q "$TARGET_IQN"; then
    fail "Target IQN not found in discovery output"
fi
log "Discovery OK"

# -------------------------------------------------------
# Step 3: Login
# -------------------------------------------------------
log "Logging in to target..."
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --login || {
    fail "Login failed"
}
log "Login OK"

# Wait for device to appear
sleep 2
# Extract the kernel-assigned disk name (e.g. "sdb") from the session details.
ISCSI_DEV=$(iscsiadm -m session -P 3 2>/dev/null | grep "Attached scsi disk" | awk '{print $4}' | head -1)
if [[ -z "$ISCSI_DEV" ]]; then
    warn "Could not determine attached device — trying /dev/sdb"
    ISCSI_DEV="sdb"
fi
DEV_PATH="/dev/$ISCSI_DEV"
log "Attached device: $DEV_PATH"

if [[ ! -b "$DEV_PATH" ]]; then
    fail "Block device $DEV_PATH not found"
fi

# -------------------------------------------------------
# Step 4: dd write/read test
# -------------------------------------------------------
log "Writing test pattern with dd..."
# oflag=direct / iflag=direct bypass the page cache so the round trip
# exercises the target, not a cached copy.
dd if=/dev/urandom of=/tmp/iscsi-smoke-pattern bs=1M count=1 2>/dev/null
dd if=/tmp/iscsi-smoke-pattern of="$DEV_PATH" bs=1M count=1 oflag=direct 2>/dev/null || {
    fail "dd write failed"
}

log "Reading back and verifying..."
dd if="$DEV_PATH" of=/tmp/iscsi-smoke-readback bs=1M count=1 iflag=direct 2>/dev/null || {
    fail "dd read failed"
}

if cmp -s /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback; then
    log "Data integrity verified (dd write/read OK)"
else
    fail "Data integrity check failed!"
fi
rm -f /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback

# -------------------------------------------------------
# Step 5: mkfs + mount + file I/O
# -------------------------------------------------------
log "Creating ext4 filesystem..."
mkfs.ext4 -F -q "$DEV_PATH" || {
    warn "mkfs.ext4 failed — skipping filesystem tests"
    # Still consider the test a pass if dd worked
    log "SMOKE TEST PASSED (dd only, mkfs skipped)"
    exit 0
}

mkdir -p "$MOUNT_POINT"
mount "$DEV_PATH" "$MOUNT_POINT" || {
    fail "mount failed"
}
log "Mounted at $MOUNT_POINT"

# Write a file
echo "SeaweedFS iSCSI smoke test" > "$MOUNT_POINT/test.txt"
sync

# Read it back
CONTENT=$(cat "$MOUNT_POINT/test.txt")
if [[ "$CONTENT" != "SeaweedFS iSCSI smoke test" ]]; then
    fail "File content mismatch"
fi
log "File I/O verified"

# -------------------------------------------------------
# Step 6: Logout
# -------------------------------------------------------
log "Unmounting and logging out..."
umount "$MOUNT_POINT"
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout || {
    fail "Logout failed"
}
log "Logout OK"

# Verify no stale sessions
# grep -c exits 1 when the count is 0; `|| true` keeps set -e happy.
SESSION_COUNT=$(iscsiadm -m session 2>/dev/null | grep -c "$TARGET_IQN" || true)
if [[ "$SESSION_COUNT" -gt 0 ]]; then
    fail "Stale session detected after logout"
fi
log "No stale sessions"

# -------------------------------------------------------
# Result
# -------------------------------------------------------
echo ""
echo -e "${GREEN}========================================${NC}"
echo -e "${GREEN} SMOKE TEST PASSED${NC}"
echo -e "${GREEN}========================================${NC}"
echo ""
echo " Discovery: OK"
echo " Login: OK"
echo " dd I/O: OK"
echo " mkfs+mount: OK"
echo " File I/O: OK"
echo " Logout: OK"
echo " Cleanup: OK"
echo ""
|||
@ -0,0 +1,207 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"errors" |
|||
"io" |
|||
) |
|||
|
|||
var (
	// ErrDataSNOrder reports a Data-Out PDU whose DataSN does not match the
	// next expected sequence number.
	ErrDataSNOrder = errors.New("iscsi: Data-Out DataSN out of order")
	// ErrDataOverflow reports payload bytes that extend beyond the expected
	// transfer length.
	ErrDataOverflow = errors.New("iscsi: data exceeds expected transfer length")
	// ErrDataIncomplete reports a transfer that ended before all expected
	// bytes arrived.
	ErrDataIncomplete = errors.New("iscsi: data transfer incomplete")
)
|||
|
|||
// DataInWriter splits a read response into multiple Data-In PDUs, respecting
// MaxRecvDataSegmentLength. The final PDU carries the S-bit (status) and F-bit.
type DataInWriter struct {
	maxSegLen uint32 // negotiated MaxRecvDataSegmentLength
}

// NewDataInWriter creates a writer with the given max segment length.
// A zero maxSegLen falls back to an 8 KiB default.
func NewDataInWriter(maxSegLen uint32) *DataInWriter {
	dw := &DataInWriter{maxSegLen: maxSegLen}
	if dw.maxSegLen == 0 {
		dw.maxSegLen = 8192 // sensible default
	}
	return dw
}
|||
|
|||
// WriteDataIn splits data into Data-In PDUs and writes them to w.
|
|||
// itt is the initiator task tag, ttt is the target transfer tag.
|
|||
// statSN is the current StatSN (incremented when S-bit is set).
|
|||
// Returns the number of PDUs written.
|
|||
func (d *DataInWriter) WriteDataIn(w io.Writer, data []byte, itt uint32, expCmdSN, maxCmdSN uint32, statSN *uint32) (int, error) { |
|||
totalLen := uint32(len(data)) |
|||
if totalLen == 0 { |
|||
// Zero-length read — send single Data-In with S-bit, no data
|
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataIn) |
|||
pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
|
|||
pdu.SetInitiatorTaskTag(itt) |
|||
pdu.SetTargetTransferTag(0xFFFFFFFF) |
|||
pdu.SetStatSN(*statSN) |
|||
*statSN++ |
|||
pdu.SetExpCmdSN(expCmdSN) |
|||
pdu.SetMaxCmdSN(maxCmdSN) |
|||
pdu.SetDataSN(0) |
|||
pdu.SetSCSIStatus(SCSIStatusGood) |
|||
if err := WritePDU(w, pdu); err != nil { |
|||
return 0, err |
|||
} |
|||
return 1, nil |
|||
} |
|||
|
|||
var offset uint32 |
|||
var dataSN uint32 |
|||
count := 0 |
|||
|
|||
for offset < totalLen { |
|||
segLen := d.maxSegLen |
|||
if offset+segLen > totalLen { |
|||
segLen = totalLen - offset |
|||
} |
|||
isFinal := (offset + segLen) >= totalLen |
|||
|
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataIn) |
|||
pdu.SetInitiatorTaskTag(itt) |
|||
pdu.SetTargetTransferTag(0xFFFFFFFF) |
|||
pdu.SetExpCmdSN(expCmdSN) |
|||
pdu.SetMaxCmdSN(maxCmdSN) |
|||
pdu.SetDataSN(dataSN) |
|||
pdu.SetBufferOffset(offset) |
|||
|
|||
pdu.DataSegment = data[offset : offset+segLen] |
|||
|
|||
if isFinal { |
|||
pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
|
|||
pdu.SetStatSN(*statSN) |
|||
*statSN++ |
|||
pdu.SetSCSIStatus(SCSIStatusGood) |
|||
} else { |
|||
pdu.SetOpSpecific1(0) // no flags
|
|||
} |
|||
|
|||
if err := WritePDU(w, pdu); err != nil { |
|||
return count, err |
|||
} |
|||
count++ |
|||
dataSN++ |
|||
offset += segLen |
|||
} |
|||
|
|||
return count, nil |
|||
} |
|||
|
|||
// DataOutCollector collects Data-Out PDUs for a single write command,
// assembling the full data buffer from potentially multiple PDUs.
// It handles both immediate data and R2T-solicited data.
type DataOutCollector struct {
	expectedLen uint32 // total bytes the command transfers
	buf         []byte // assembled payload, len == expectedLen
	received    uint32 // bytes accumulated so far
	nextDataSN  uint32 // next expected Data-Out sequence number
	done        bool   // true once the transfer is complete
}

// NewDataOutCollector creates a collector expecting the given total transfer length.
func NewDataOutCollector(expectedLen uint32) *DataOutCollector {
	c := &DataOutCollector{expectedLen: expectedLen}
	c.buf = make([]byte, expectedLen)
	return c
}
|||
|
|||
// AddImmediateData adds the data from the SCSI Command PDU (immediate data).
|
|||
func (c *DataOutCollector) AddImmediateData(data []byte) error { |
|||
if uint32(len(data)) > c.expectedLen { |
|||
return ErrDataOverflow |
|||
} |
|||
copy(c.buf, data) |
|||
c.received += uint32(len(data)) |
|||
if c.received >= c.expectedLen { |
|||
c.done = true |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// AddDataOut processes a Data-Out PDU and adds its data to the buffer.
|
|||
func (c *DataOutCollector) AddDataOut(pdu *PDU) error { |
|||
dataSN := pdu.DataSN() |
|||
if dataSN != c.nextDataSN { |
|||
return ErrDataSNOrder |
|||
} |
|||
c.nextDataSN++ |
|||
|
|||
offset := pdu.BufferOffset() |
|||
data := pdu.DataSegment |
|||
end := offset + uint32(len(data)) |
|||
|
|||
if end > c.expectedLen { |
|||
return ErrDataOverflow |
|||
} |
|||
|
|||
copy(c.buf[offset:], data) |
|||
c.received += uint32(len(data)) |
|||
|
|||
// Check F-bit
|
|||
if pdu.OpSpecific1()&FlagF != 0 { |
|||
c.done = true |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Done returns true if all expected data has been received.
|
|||
func (c *DataOutCollector) Done() bool { return c.done } |
|||
|
|||
// Data returns the assembled data buffer.
|
|||
func (c *DataOutCollector) Data() []byte { return c.buf } |
|||
|
|||
// Remaining returns how many bytes are still needed.
|
|||
func (c *DataOutCollector) Remaining() uint32 { |
|||
if c.received >= c.expectedLen { |
|||
return 0 |
|||
} |
|||
return c.expectedLen - c.received |
|||
} |
|||
|
|||
// BuildR2T creates an R2T PDU requesting more data from the initiator.
|
|||
func BuildR2T(itt, ttt uint32, r2tSN uint32, bufferOffset, desiredLen uint32, statSN, expCmdSN, maxCmdSN uint32) *PDU { |
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpR2T) |
|||
pdu.SetOpSpecific1(FlagF) // always Final for R2T
|
|||
pdu.SetInitiatorTaskTag(itt) |
|||
pdu.SetTargetTransferTag(ttt) |
|||
pdu.SetStatSN(statSN) |
|||
pdu.SetExpCmdSN(expCmdSN) |
|||
pdu.SetMaxCmdSN(maxCmdSN) |
|||
pdu.SetR2TSN(r2tSN) |
|||
pdu.SetBufferOffset(bufferOffset) |
|||
pdu.SetDesiredDataLength(desiredLen) |
|||
return pdu |
|||
} |
|||
|
|||
// SendSCSIResponse sends a SCSI Response PDU with optional sense data.
|
|||
func SendSCSIResponse(w io.Writer, result SCSIResult, itt uint32, statSN *uint32, expCmdSN, maxCmdSN uint32) error { |
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIResp) |
|||
pdu.SetOpSpecific1(FlagF) // Final
|
|||
pdu.SetSCSIResponse(ISCSIRespCompleted) |
|||
pdu.SetSCSIStatus(result.Status) |
|||
pdu.SetInitiatorTaskTag(itt) |
|||
pdu.SetStatSN(*statSN) |
|||
*statSN++ |
|||
pdu.SetExpCmdSN(expCmdSN) |
|||
pdu.SetMaxCmdSN(maxCmdSN) |
|||
|
|||
if result.Status == SCSIStatusCheckCond { |
|||
senseData := BuildSenseData(result.SenseKey, result.SenseASC, result.SenseASCQ) |
|||
// Sense data is wrapped in a 2-byte length prefix
|
|||
pdu.DataSegment = make([]byte, 2+len(senseData)) |
|||
pdu.DataSegment[0] = byte(len(senseData) >> 8) |
|||
pdu.DataSegment[1] = byte(len(senseData)) |
|||
copy(pdu.DataSegment[2:], senseData) |
|||
} |
|||
|
|||
return WritePDU(w, pdu) |
|||
} |
|||
@ -0,0 +1,410 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"bytes" |
|||
"testing" |
|||
) |
|||
|
|||
func TestDataIO(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{"datain_single_pdu", testDataInSinglePDU}, |
|||
{"datain_multi_pdu", testDataInMultiPDU}, |
|||
{"datain_exact_boundary", testDataInExactBoundary}, |
|||
{"datain_zero_length", testDataInZeroLength}, |
|||
{"datain_datasn_ordering", testDataInDataSNOrdering}, |
|||
{"datain_fbit_sbit", testDataInFbitSbit}, |
|||
{"dataout_single_pdu", testDataOutSinglePDU}, |
|||
{"dataout_multi_pdu", testDataOutMultiPDU}, |
|||
{"dataout_immediate_data", testDataOutImmediateData}, |
|||
{"dataout_immediate_plus_r2t", testDataOutImmediatePlusR2T}, |
|||
{"dataout_wrong_datasn", testDataOutWrongDataSN}, |
|||
{"dataout_overflow", testDataOutOverflow}, |
|||
{"r2t_build", testR2TBuild}, |
|||
{"scsi_response_good", testSCSIResponseGood}, |
|||
{"scsi_response_check_condition", testSCSIResponseCheckCondition}, |
|||
{"datain_statsn_increment", testDataInStatSNIncrement}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testDataInSinglePDU(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(8192) |
|||
data := bytes.Repeat([]byte{0xAA}, 4096) |
|||
statSN := uint32(1) |
|||
|
|||
n, err := dw.WriteDataIn(w, data, 0x100, 1, 10, &statSN) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if n != 1 { |
|||
t.Fatalf("expected 1 PDU, got %d", n) |
|||
} |
|||
if statSN != 2 { |
|||
t.Fatalf("StatSN should be 2, got %d", statSN) |
|||
} |
|||
|
|||
pdu, err := ReadPDU(w) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if pdu.Opcode() != OpSCSIDataIn { |
|||
t.Fatal("wrong opcode") |
|||
} |
|||
if pdu.OpSpecific1()&FlagF == 0 { |
|||
t.Fatal("F-bit not set") |
|||
} |
|||
if pdu.OpSpecific1()&FlagS == 0 { |
|||
t.Fatal("S-bit not set on final PDU") |
|||
} |
|||
if len(pdu.DataSegment) != 4096 { |
|||
t.Fatalf("data length: %d", len(pdu.DataSegment)) |
|||
} |
|||
} |
|||
|
|||
func testDataInMultiPDU(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(1024) // small segment
|
|||
data := bytes.Repeat([]byte{0xBB}, 3000) |
|||
statSN := uint32(1) |
|||
|
|||
n, err := dw.WriteDataIn(w, data, 0x200, 1, 10, &statSN) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
// 3000 / 1024 = 3 PDUs (1024 + 1024 + 952)
|
|||
if n != 3 { |
|||
t.Fatalf("expected 3 PDUs, got %d", n) |
|||
} |
|||
|
|||
// Read them back
|
|||
var reassembled []byte |
|||
for i := 0; i < 3; i++ { |
|||
pdu, err := ReadPDU(w) |
|||
if err != nil { |
|||
t.Fatalf("PDU %d: %v", i, err) |
|||
} |
|||
if pdu.DataSN() != uint32(i) { |
|||
t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN()) |
|||
} |
|||
reassembled = append(reassembled, pdu.DataSegment...) |
|||
|
|||
if i < 2 { |
|||
if pdu.OpSpecific1()&FlagF != 0 { |
|||
t.Fatalf("PDU %d should not have F-bit", i) |
|||
} |
|||
} else { |
|||
if pdu.OpSpecific1()&FlagF == 0 { |
|||
t.Fatal("last PDU should have F-bit") |
|||
} |
|||
if pdu.OpSpecific1()&FlagS == 0 { |
|||
t.Fatal("last PDU should have S-bit") |
|||
} |
|||
} |
|||
} |
|||
|
|||
if !bytes.Equal(reassembled, data) { |
|||
t.Fatal("reassembled data mismatch") |
|||
} |
|||
} |
|||
|
|||
func testDataInExactBoundary(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(1024) |
|||
data := bytes.Repeat([]byte{0xCC}, 2048) // exact 2 PDUs
|
|||
statSN := uint32(1) |
|||
|
|||
n, err := dw.WriteDataIn(w, data, 0x300, 1, 10, &statSN) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if n != 2 { |
|||
t.Fatalf("expected 2 PDUs, got %d", n) |
|||
} |
|||
} |
|||
|
|||
func testDataInZeroLength(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(8192) |
|||
statSN := uint32(5) |
|||
|
|||
n, err := dw.WriteDataIn(w, nil, 0x400, 1, 10, &statSN) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if n != 1 { |
|||
t.Fatalf("expected 1 PDU for zero-length, got %d", n) |
|||
} |
|||
if statSN != 6 { |
|||
t.Fatal("StatSN should still increment") |
|||
} |
|||
} |
|||
|
|||
func testDataInDataSNOrdering(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(512) |
|||
data := bytes.Repeat([]byte{0xDD}, 2048) // 4 PDUs
|
|||
statSN := uint32(1) |
|||
|
|||
dw.WriteDataIn(w, data, 0x500, 1, 10, &statSN) |
|||
|
|||
for i := 0; i < 4; i++ { |
|||
pdu, _ := ReadPDU(w) |
|||
if pdu.DataSN() != uint32(i) { |
|||
t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN()) |
|||
} |
|||
expectedOffset := uint32(i) * 512 |
|||
if pdu.BufferOffset() != expectedOffset { |
|||
t.Fatalf("PDU %d: offset=%d, expected %d", i, pdu.BufferOffset(), expectedOffset) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testDataInFbitSbit(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(1000) |
|||
data := bytes.Repeat([]byte{0xEE}, 2500) |
|||
statSN := uint32(1) |
|||
|
|||
dw.WriteDataIn(w, data, 0x600, 1, 10, &statSN) |
|||
|
|||
for i := 0; i < 3; i++ { |
|||
pdu, _ := ReadPDU(w) |
|||
flags := pdu.OpSpecific1() |
|||
if i < 2 { |
|||
if flags&FlagF != 0 || flags&FlagS != 0 { |
|||
t.Fatalf("PDU %d: should have no F/S bits", i) |
|||
} |
|||
} else { |
|||
if flags&FlagF == 0 || flags&FlagS == 0 { |
|||
t.Fatal("last PDU must have F+S bits") |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testDataOutSinglePDU(t *testing.T) { |
|||
c := NewDataOutCollector(4096) |
|||
|
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataOut) |
|||
pdu.SetOpSpecific1(FlagF) |
|||
pdu.SetDataSN(0) |
|||
pdu.SetBufferOffset(0) |
|||
pdu.DataSegment = bytes.Repeat([]byte{0x11}, 4096) |
|||
|
|||
if err := c.AddDataOut(pdu); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if !c.Done() { |
|||
t.Fatal("should be done") |
|||
} |
|||
if c.Remaining() != 0 { |
|||
t.Fatal("remaining should be 0") |
|||
} |
|||
} |
|||
|
|||
func testDataOutMultiPDU(t *testing.T) { |
|||
c := NewDataOutCollector(8192) |
|||
|
|||
for i := 0; i < 2; i++ { |
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataOut) |
|||
pdu.SetDataSN(uint32(i)) |
|||
pdu.SetBufferOffset(uint32(i) * 4096) |
|||
pdu.DataSegment = bytes.Repeat([]byte{byte(i + 1)}, 4096) |
|||
if i == 1 { |
|||
pdu.SetOpSpecific1(FlagF) |
|||
} |
|||
if err := c.AddDataOut(pdu); err != nil { |
|||
t.Fatalf("PDU %d: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
if !c.Done() { |
|||
t.Fatal("should be done") |
|||
} |
|||
data := c.Data() |
|||
if data[0] != 0x01 || data[4096] != 0x02 { |
|||
t.Fatal("data assembly wrong") |
|||
} |
|||
} |
|||
|
|||
func testDataOutImmediateData(t *testing.T) { |
|||
c := NewDataOutCollector(4096) |
|||
err := c.AddImmediateData(bytes.Repeat([]byte{0xFF}, 4096)) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if !c.Done() { |
|||
t.Fatal("should be done with immediate data") |
|||
} |
|||
} |
|||
|
|||
// testDataOutImmediatePlusR2T exercises the mixed transfer path: the first
// half of an 8192-byte write arrives as immediate data, the second half as
// an R2T-solicited Data-Out PDU.
func testDataOutImmediatePlusR2T(t *testing.T) {
	c := NewDataOutCollector(8192)

	// Immediate: first 4096
	err := c.AddImmediateData(bytes.Repeat([]byte{0xAA}, 4096))
	if err != nil {
		t.Fatal(err)
	}
	if c.Done() {
		t.Fatal("should not be done yet")
	}
	if c.Remaining() != 4096 {
		t.Fatalf("remaining: %d", c.Remaining())
	}

	// R2T-solicited Data-Out: next 4096
	pdu := &PDU{}
	pdu.SetOpcode(OpSCSIDataOut)
	pdu.SetOpSpecific1(FlagF)
	pdu.SetDataSN(0) // DataSN starts at 0 for the solicited sequence
	pdu.SetBufferOffset(4096)
	pdu.DataSegment = bytes.Repeat([]byte{0xBB}, 4096)
	if err := c.AddDataOut(pdu); err != nil {
		t.Fatal(err)
	}
	if !c.Done() {
		t.Fatal("should be done")
	}

	// Each half must land at its buffer offset.
	data := c.Data()
	if data[0] != 0xAA || data[4096] != 0xBB {
		t.Fatal("assembly wrong")
	}
}
|||
|
|||
func testDataOutWrongDataSN(t *testing.T) { |
|||
c := NewDataOutCollector(8192) |
|||
|
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataOut) |
|||
pdu.SetDataSN(1) // should be 0
|
|||
pdu.SetBufferOffset(0) |
|||
pdu.DataSegment = make([]byte, 4096) |
|||
|
|||
err := c.AddDataOut(pdu) |
|||
if err != ErrDataSNOrder { |
|||
t.Fatalf("expected ErrDataSNOrder, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testDataOutOverflow(t *testing.T) { |
|||
c := NewDataOutCollector(4096) |
|||
|
|||
pdu := &PDU{} |
|||
pdu.SetOpcode(OpSCSIDataOut) |
|||
pdu.SetDataSN(0) |
|||
pdu.SetBufferOffset(0) |
|||
pdu.DataSegment = make([]byte, 8192) // more than expected
|
|||
|
|||
err := c.AddDataOut(pdu) |
|||
if err != ErrDataOverflow { |
|||
t.Fatalf("expected ErrDataOverflow, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testR2TBuild(t *testing.T) { |
|||
pdu := BuildR2T(0x100, 0x200, 0, 4096, 4096, 5, 10, 20) |
|||
if pdu.Opcode() != OpR2T { |
|||
t.Fatal("wrong opcode") |
|||
} |
|||
if pdu.InitiatorTaskTag() != 0x100 { |
|||
t.Fatal("ITT wrong") |
|||
} |
|||
if pdu.TargetTransferTag() != 0x200 { |
|||
t.Fatal("TTT wrong") |
|||
} |
|||
if pdu.R2TSN() != 0 { |
|||
t.Fatal("R2TSN wrong") |
|||
} |
|||
if pdu.BufferOffset() != 4096 { |
|||
t.Fatal("offset wrong") |
|||
} |
|||
if pdu.DesiredDataLength() != 4096 { |
|||
t.Fatal("desired length wrong") |
|||
} |
|||
if pdu.StatSN() != 5 { |
|||
t.Fatal("StatSN wrong") |
|||
} |
|||
} |
|||
|
|||
// testSCSIResponseGood checks a GOOD-status SCSI Response: StatSN is
// consumed (incremented) and the PDU carries no sense data.
func testSCSIResponseGood(t *testing.T) {
	w := &bytes.Buffer{}
	statSN := uint32(10)
	result := SCSIResult{Status: SCSIStatusGood}
	err := SendSCSIResponse(w, result, 0x300, &statSN, 5, 15)
	if err != nil {
		t.Fatal(err)
	}
	// Sending a response consumes one StatSN.
	if statSN != 11 {
		t.Fatal("StatSN not incremented")
	}

	pdu, err := ReadPDU(w)
	if err != nil {
		t.Fatal(err)
	}
	if pdu.Opcode() != OpSCSIResp {
		t.Fatal("wrong opcode")
	}
	if pdu.SCSIStatus() != SCSIStatusGood {
		t.Fatal("status wrong")
	}
	if len(pdu.DataSegment) != 0 {
		t.Fatal("no data expected for good status")
	}
}
|||
|
|||
// testSCSIResponseCheckCondition checks that a CHECK CONDITION response
// carries fixed-format sense data (18 bytes) behind a 2-byte big-endian
// length prefix in the data segment.
func testSCSIResponseCheckCondition(t *testing.T) {
	w := &bytes.Buffer{}
	statSN := uint32(20)
	result := SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  SenseIllegalRequest,
		SenseASC:  ASCInvalidOpcode,
		SenseASCQ: ASCQLuk,
	}
	err := SendSCSIResponse(w, result, 0x400, &statSN, 5, 15)
	if err != nil {
		t.Fatal(err)
	}

	pdu, err := ReadPDU(w)
	if err != nil {
		t.Fatal(err)
	}
	if pdu.SCSIStatus() != SCSIStatusCheckCond {
		t.Fatal("status wrong")
	}
	// Data segment should contain sense data with 2-byte length prefix
	if len(pdu.DataSegment) < 20 { // 2 + 18
		t.Fatalf("sense data too short: %d", len(pdu.DataSegment))
	}
	// Length prefix is big-endian.
	senseLen := int(pdu.DataSegment[0])<<8 | int(pdu.DataSegment[1])
	if senseLen != 18 {
		t.Fatalf("sense length: %d", senseLen)
	}
}
|||
|
|||
func testDataInStatSNIncrement(t *testing.T) { |
|||
w := &bytes.Buffer{} |
|||
dw := NewDataInWriter(1024) |
|||
data := bytes.Repeat([]byte{0x00}, 3072) // 3 PDUs
|
|||
statSN := uint32(100) |
|||
|
|||
dw.WriteDataIn(w, data, 0x700, 1, 10, &statSN) |
|||
// Only the final PDU has S-bit, so StatSN increments once
|
|||
if statSN != 101 { |
|||
t.Fatalf("StatSN should be 101, got %d", statSN) |
|||
} |
|||
} |
|||
@ -0,0 +1,68 @@ |
|||
package iscsi |
|||
|
|||
// HandleTextRequest processes a Text Request PDU.
|
|||
// Currently supports SendTargets discovery (RFC 7143, Section 12.3).
|
|||
func HandleTextRequest(req *PDU, targets []DiscoveryTarget) *PDU { |
|||
resp := &PDU{} |
|||
resp.SetOpcode(OpTextResp) |
|||
resp.SetOpSpecific1(FlagF) // Final
|
|||
resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) |
|||
resp.SetTargetTransferTag(0xFFFFFFFF) // no continuation
|
|||
|
|||
params, err := ParseParams(req.DataSegment) |
|||
if err != nil { |
|||
// Malformed — return empty response
|
|||
return resp |
|||
} |
|||
|
|||
val, ok := params.Get("SendTargets") |
|||
if !ok { |
|||
return resp |
|||
} |
|||
|
|||
// Discovery responses can have duplicate keys (TargetName appears
|
|||
// for each target), so we use EncodeDiscoveryTargets directly.
|
|||
var matched []DiscoveryTarget |
|||
|
|||
switch val { |
|||
case "All": |
|||
matched = targets |
|||
default: |
|||
for _, tgt := range targets { |
|||
if tgt.Name == val { |
|||
matched = append(matched, tgt) |
|||
break |
|||
} |
|||
} |
|||
} |
|||
|
|||
if data := EncodeDiscoveryTargets(matched); len(data) > 0 { |
|||
resp.DataSegment = data |
|||
} |
|||
return resp |
|||
} |
|||
|
|||
// DiscoveryTarget represents a target available for discovery.
|
|||
// DiscoveryTarget represents a target available for discovery.
type DiscoveryTarget struct {
	Name    string // IQN, e.g., "iqn.2024.com.seaweedfs:vol1"
	Address string // IP:port,portal-group, e.g., "10.0.0.1:3260,1"
}

// EncodeDiscoveryTargets encodes multiple targets into text parameter format.
// Each target produces TargetName=<iqn>\0TargetAddress=<addr>\0
// This handles the special multi-value encoding where the same key can appear
// multiple times (unlike normal negotiation params).
func EncodeDiscoveryTargets(targets []DiscoveryTarget) []byte {
	if len(targets) == 0 {
		return nil
	}

	// Grow one buffer piecewise instead of concatenating strings per target.
	out := make([]byte, 0, 64*len(targets))
	for _, tgt := range targets {
		out = append(out, "TargetName="...)
		out = append(out, tgt.Name...)
		out = append(out, 0)
		if tgt.Address != "" {
			out = append(out, "TargetAddress="...)
			out = append(out, tgt.Address...)
			out = append(out, 0)
		}
	}
	return out
}
|||
@ -0,0 +1,213 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
// TestDiscovery drives the SendTargets discovery subtests through a
// name/func table so each case appears as its own t.Run entry.
func TestDiscovery(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"send_targets_all", testSendTargetsAll},
		{"send_targets_specific", testSendTargetsSpecific},
		{"send_targets_not_found", testSendTargetsNotFound},
		{"send_targets_empty_list", testSendTargetsEmptyList},
		{"no_send_targets_key", testNoSendTargetsKey},
		{"malformed_text_request", testMalformedTextRequest},
		{"special_chars_in_iqn", testSpecialCharsInIQN},
		{"multiple_targets", testMultipleTargets},
		{"encode_discovery_targets", testEncodeDiscoveryTargets},
		{"encode_empty_targets", testEncodeEmptyTargets},
		{"target_without_address", testTargetWithoutAddress},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// testTargets is the shared two-target fixture used by the SendTargets tests.
var testTargets = []DiscoveryTarget{
	{Name: "iqn.2024.com.seaweedfs:vol1", Address: "10.0.0.1:3260,1"},
	{Name: "iqn.2024.com.seaweedfs:vol2", Address: "10.0.0.2:3260,1"},
}
|||
|
|||
func makeTextReq(params *Params) *PDU { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpTextReq) |
|||
p.SetOpSpecific1(FlagF) |
|||
p.SetInitiatorTaskTag(0x1234) |
|||
if params != nil { |
|||
p.DataSegment = params.Encode() |
|||
} |
|||
return p |
|||
} |
|||
|
|||
func testSendTargetsAll(t *testing.T) { |
|||
params := NewParams() |
|||
params.Set("SendTargets", "All") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, testTargets) |
|||
|
|||
if resp.Opcode() != OpTextResp { |
|||
t.Fatalf("opcode: 0x%02x", resp.Opcode()) |
|||
} |
|||
if resp.InitiatorTaskTag() != 0x1234 { |
|||
t.Fatal("ITT mismatch") |
|||
} |
|||
|
|||
body := string(resp.DataSegment) |
|||
if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol1") { |
|||
t.Fatal("missing vol1") |
|||
} |
|||
if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol2") { |
|||
t.Fatal("missing vol2") |
|||
} |
|||
if !strings.Contains(body, "10.0.0.1:3260,1") { |
|||
t.Fatal("missing vol1 address") |
|||
} |
|||
} |
|||
|
|||
func testSendTargetsSpecific(t *testing.T) { |
|||
params := NewParams() |
|||
params.Set("SendTargets", "iqn.2024.com.seaweedfs:vol2") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, testTargets) |
|||
body := string(resp.DataSegment) |
|||
if !strings.Contains(body, "vol2") { |
|||
t.Fatal("missing vol2") |
|||
} |
|||
// Should not contain vol1
|
|||
if strings.Contains(body, "vol1") { |
|||
t.Fatal("should not contain vol1") |
|||
} |
|||
} |
|||
|
|||
func testSendTargetsNotFound(t *testing.T) { |
|||
params := NewParams() |
|||
params.Set("SendTargets", "iqn.2024.com.seaweedfs:nonexistent") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, testTargets) |
|||
if len(resp.DataSegment) != 0 { |
|||
t.Fatalf("expected empty response, got %q", resp.DataSegment) |
|||
} |
|||
} |
|||
|
|||
func testSendTargetsEmptyList(t *testing.T) { |
|||
params := NewParams() |
|||
params.Set("SendTargets", "All") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, nil) |
|||
if len(resp.DataSegment) != 0 { |
|||
t.Fatalf("expected empty response, got %q", resp.DataSegment) |
|||
} |
|||
} |
|||
|
|||
func testNoSendTargetsKey(t *testing.T) { |
|||
params := NewParams() |
|||
params.Set("SomethingElse", "value") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, testTargets) |
|||
if len(resp.DataSegment) != 0 { |
|||
t.Fatal("expected empty response for non-SendTargets request") |
|||
} |
|||
} |
|||
|
|||
// testMalformedTextRequest: a data segment that fails ParseParams must still
// produce a well-formed, empty Text Response (no panic, no error leak).
func testMalformedTextRequest(t *testing.T) {
	req := &PDU{}
	req.SetOpcode(OpTextReq)
	req.SetInitiatorTaskTag(0x5678)
	req.DataSegment = []byte("not a valid param format") // no '='

	resp := HandleTextRequest(req, testTargets)
	// Should return empty response without error
	if resp.Opcode() != OpTextResp {
		t.Fatalf("opcode: 0x%02x", resp.Opcode())
	}
	if resp.InitiatorTaskTag() != 0x5678 {
		t.Fatal("ITT mismatch")
	}
}
|||
|
|||
func testSpecialCharsInIQN(t *testing.T) { |
|||
targets := []DiscoveryTarget{ |
|||
{Name: "iqn.2024-01.com.example:storage.tape1.sys1.xyz", Address: "192.168.1.100:3260,1"}, |
|||
} |
|||
params := NewParams() |
|||
params.Set("SendTargets", "All") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, targets) |
|||
body := string(resp.DataSegment) |
|||
if !strings.Contains(body, "iqn.2024-01.com.example:storage.tape1.sys1.xyz") { |
|||
t.Fatal("IQN with special chars not found") |
|||
} |
|||
} |
|||
|
|||
func testMultipleTargets(t *testing.T) { |
|||
targets := make([]DiscoveryTarget, 10) |
|||
for i := range targets { |
|||
targets[i] = DiscoveryTarget{ |
|||
Name: "iqn.2024.com.test:vol" + string(rune('0'+i)), |
|||
Address: "10.0.0.1:3260,1", |
|||
} |
|||
} |
|||
|
|||
params := NewParams() |
|||
params.Set("SendTargets", "All") |
|||
req := makeTextReq(params) |
|||
|
|||
resp := HandleTextRequest(req, targets) |
|||
body := string(resp.DataSegment) |
|||
// The response uses EncodeDiscoveryTargets internally via the params,
|
|||
// but since Params doesn't allow duplicate keys, the last one wins.
|
|||
// This is a known limitation — for multi-target discovery, we use
|
|||
// EncodeDiscoveryTargets directly. Let's verify at least the last target.
|
|||
if !strings.Contains(body, "TargetName=") { |
|||
t.Fatal("no TargetName in response") |
|||
} |
|||
} |
|||
|
|||
func testEncodeDiscoveryTargets(t *testing.T) { |
|||
encoded := EncodeDiscoveryTargets(testTargets) |
|||
body := string(encoded) |
|||
|
|||
// Should contain both targets with proper format
|
|||
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { |
|||
t.Fatal("missing vol1") |
|||
} |
|||
if !strings.Contains(body, "TargetAddress=10.0.0.1:3260,1\x00") { |
|||
t.Fatal("missing vol1 address") |
|||
} |
|||
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol2\x00") { |
|||
t.Fatal("missing vol2") |
|||
} |
|||
} |
|||
|
|||
func testEncodeEmptyTargets(t *testing.T) { |
|||
encoded := EncodeDiscoveryTargets(nil) |
|||
if encoded != nil { |
|||
t.Fatalf("expected nil, got %q", encoded) |
|||
} |
|||
} |
|||
|
|||
func testTargetWithoutAddress(t *testing.T) { |
|||
targets := []DiscoveryTarget{ |
|||
{Name: "iqn.2024.com.seaweedfs:vol1"}, // no address
|
|||
} |
|||
encoded := EncodeDiscoveryTargets(targets) |
|||
body := string(encoded) |
|||
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { |
|||
t.Fatal("missing target name") |
|||
} |
|||
if strings.Contains(body, "TargetAddress") { |
|||
t.Fatal("should not have TargetAddress") |
|||
} |
|||
} |
|||
@ -0,0 +1,412 @@ |
|||
package iscsi_test |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"path/filepath" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol" |
|||
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi" |
|||
) |
|||
|
|||
// blockVolAdapter wraps a BlockVol to implement iscsi.BlockDevice.
// Every method is a thin delegation to the underlying volume.
type blockVolAdapter struct {
	vol *blockvol.BlockVol
}

// ReadAt returns length bytes starting at the given LBA.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return a.vol.ReadLBA(lba, length)
}

// WriteAt writes data starting at the given LBA.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
	return a.vol.WriteLBA(lba, data)
}

// Trim discards length bytes starting at the given LBA (UNMAP backing).
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
	return a.vol.Trim(lba, length)
}

// SyncCache flushes pending writes to stable storage.
func (a *blockVolAdapter) SyncCache() error {
	return a.vol.SyncCache()
}

// BlockSize reports the volume's logical block size in bytes.
func (a *blockVolAdapter) BlockSize() uint32 {
	return a.vol.Info().BlockSize
}

// VolumeSize reports the volume's total size in bytes.
func (a *blockVolAdapter) VolumeSize() uint64 {
	return a.vol.Info().VolumeSize
}

// IsHealthy reports the volume's health flag.
func (a *blockVolAdapter) IsHealthy() bool {
	return a.vol.Info().Healthy
}
|||
|
|||
// Fixed IQNs used by the integration tests for the target and the initiator.
const (
	intTargetName    = "iqn.2024.com.seaweedfs:integration"
	intInitiatorName = "iqn.2024.com.test:client"
)
|||
|
|||
// createTestVol creates a throwaway 4MB block volume (1024 x 4096-byte
// blocks) in a temp dir and registers its cleanup with the test.
func createTestVol(t *testing.T) *blockvol.BlockVol {
	t.Helper()
	path := filepath.Join(t.TempDir(), "test.blk")
	vol, err := blockvol.CreateBlockVol(path, blockvol.CreateOptions{
		VolumeSize: 1024 * 4096, // 1024 blocks = 4MB
		BlockSize:  4096,
		WALSize:    1024 * 1024, // 1MB WAL
	})
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { vol.Close() })
	return vol
}
|||
|
|||
// setupIntegrationTarget starts a real TargetServer on a loopback listener
// backed by a fresh test volume, dials it, performs a single-PDU Normal
// login (SecurityNeg -> FullFeature transit), and returns the logged-in
// connection. All resources are cleaned up via t.Cleanup.
func setupIntegrationTarget(t *testing.T) (net.Conn, *iscsi.TargetServer) {
	t.Helper()
	vol := createTestVol(t)
	adapter := &blockVolAdapter{vol: vol}

	config := iscsi.DefaultTargetConfig()
	config.TargetName = intTargetName
	logger := log.New(io.Discard, "", 0) // silence target logs in tests
	ts := iscsi.NewTargetServer("127.0.0.1:0", config, logger)
	ts.AddVolume(intTargetName, adapter)

	// Listen on an ephemeral port so parallel tests don't collide.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	go ts.Serve(ln)
	t.Cleanup(func() { ts.Close() })

	conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { conn.Close() })

	// Login
	params := iscsi.NewParams()
	params.Set("InitiatorName", intInitiatorName)
	params.Set("TargetName", intTargetName)
	params.Set("SessionType", "Normal")

	loginReq := &iscsi.PDU{}
	loginReq.SetOpcode(iscsi.OpLoginReq)
	loginReq.SetLoginStages(iscsi.StageSecurityNeg, iscsi.StageFullFeature)
	loginReq.SetLoginTransit(true) // request immediate transition to FFP
	loginReq.SetISID([6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01})
	loginReq.SetCmdSN(1)
	loginReq.DataSegment = params.Encode()

	if err := iscsi.WritePDU(conn, loginReq); err != nil {
		t.Fatal(err)
	}
	resp, err := iscsi.ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.LoginStatusClass() != iscsi.LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}

	return conn, ts
}
|||
|
|||
// TestIntegration drives the end-to-end target tests (each subtest brings
// up its own target and logged-in connection) through a name/func table.
func TestIntegration(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"write_read_verify", testWriteReadVerify},
		{"multi_block_write_read", testMultiBlockWriteRead},
		{"inquiry_capacity", testInquiryCapacity},
		{"sync_cache", testIntSyncCache},
		{"test_unit_ready", testIntTestUnitReady},
		{"concurrent_readers_writers", testConcurrentReadersWriters},
		{"write_at_boundary", testWriteAtBoundary},
		{"unmap_integration", testUnmapIntegration},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// sendSCSICmd sends one final SCSI Command PDU with the given CDB and
// optional immediate data, then reads and returns the first response PDU.
// cmdSN doubles as the initiator task tag for simplicity. Callers needing
// multi-PDU responses (split Data-In) must keep reading from conn.
func sendSCSICmd(t *testing.T, conn net.Conn, cdb [16]byte, cmdSN uint32, read bool, write bool, dataOut []byte, expLen uint32) *iscsi.PDU {
	t.Helper()
	cmd := &iscsi.PDU{}
	cmd.SetOpcode(iscsi.OpSCSICmd)
	flags := uint8(iscsi.FlagF)
	if read {
		flags |= iscsi.FlagR
	}
	if write {
		flags |= iscsi.FlagW
	}
	cmd.SetOpSpecific1(flags)
	cmd.SetInitiatorTaskTag(cmdSN) // reuse cmdSN as ITT
	cmd.SetExpectedDataTransferLength(expLen)
	cmd.SetCmdSN(cmdSN)
	cmd.SetCDB(cdb)
	if dataOut != nil {
		cmd.DataSegment = dataOut // immediate data
	}

	if err := iscsi.WritePDU(conn, cmd); err != nil {
		t.Fatal(err)
	}

	resp, err := iscsi.ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	return resp
}
|||
|
|||
// testWriteReadVerify writes one patterned block at LBA 0 via WRITE(10)
// and reads it back via READ(10), checking byte-exact integrity.
func testWriteReadVerify(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Write pattern to LBA 0
	writeData := make([]byte, 4096)
	for i := range writeData {
		writeData[i] = byte(i % 251) // prime modulus for pattern
	}

	// WRITE(10): LBA in CDB bytes 2-5, transfer length in bytes 7-8.
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 0) // LBA 0
	binary.BigEndian.PutUint16(wCDB[7:9], 1) // 1 block

	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}

	// Read it back
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 0)
	binary.BigEndian.PutUint16(rCDB[7:9], 1)

	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, 4096)
	if resp2.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", iscsi.OpcodeName(resp2.Opcode()))
	}

	if !bytes.Equal(resp2.DataSegment, writeData) {
		t.Fatal("data integrity verification failed")
	}
}
|||
|
|||
// testMultiBlockWriteRead writes four distinctly-filled blocks at LBA 10,
// reads them back (reassembling split Data-In PDUs by following the S bit),
// and checks byte-exact integrity.
func testMultiBlockWriteRead(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Write 4 blocks at LBA 10
	blockCount := uint16(4)
	dataLen := int(blockCount) * 4096
	writeData := make([]byte, dataLen)
	for i := range writeData {
		writeData[i] = byte((i / 4096) + 1) // each block has different fill
	}

	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 10)
	binary.BigEndian.PutUint16(wCDB[7:9], blockCount)

	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, uint32(dataLen))
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}

	// Read back all 4 blocks
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 10)
	binary.BigEndian.PutUint16(rCDB[7:9], blockCount)

	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, uint32(dataLen))

	// May be split across multiple Data-In PDUs — reassemble
	var readData []byte
	readData = append(readData, resp2.DataSegment...)

	// If first PDU doesn't have S-bit, read more
	for resp2.OpSpecific1()&iscsi.FlagS == 0 {
		var err error
		resp2, err = iscsi.ReadPDU(conn)
		if err != nil {
			t.Fatal(err)
		}
		readData = append(readData, resp2.DataSegment...)
	}

	if !bytes.Equal(readData, writeData) {
		t.Fatal("multi-block data integrity failed")
	}
}
|||
|
|||
// testInquiryCapacity exercises the basic read-only opcodes against the
// 4MB test volume: TEST UNIT READY, INQUIRY (standard data), and
// READ CAPACITY(10), whose payload must report last LBA 1023 and a
// 4096-byte block size.
func testInquiryCapacity(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// TEST UNIT READY
	var tuCDB [16]byte
	tuCDB[0] = iscsi.ScsiTestUnitReady
	resp := sendSCSICmd(t, conn, tuCDB, 2, false, false, nil, 0)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatal("test unit ready failed")
	}

	// INQUIRY — allocation length (96) lives in CDB bytes 3-4.
	var iqCDB [16]byte
	iqCDB[0] = iscsi.ScsiInquiry
	binary.BigEndian.PutUint16(iqCDB[3:5], 96)
	resp = sendSCSICmd(t, conn, iqCDB, 3, true, false, nil, 96)
	if resp.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatal("expected Data-In for inquiry")
	}
	// Byte 0 is the peripheral device type; 0x00 = direct-access block.
	if resp.DataSegment[0] != 0x00 { // SBC device
		t.Fatal("not SBC device type")
	}

	// READ CAPACITY 10 — 8-byte payload: last LBA then block size.
	var rcCDB [16]byte
	rcCDB[0] = iscsi.ScsiReadCapacity10
	resp = sendSCSICmd(t, conn, rcCDB, 4, true, false, nil, 8)
	if resp.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatal("expected Data-In for read capacity")
	}
	lastLBA := binary.BigEndian.Uint32(resp.DataSegment[0:4])
	blockSize := binary.BigEndian.Uint32(resp.DataSegment[4:8])
	if lastLBA != 1023 { // 1024 blocks, last = 1023
		t.Fatalf("last LBA: %d, expected 1023", lastLBA)
	}
	if blockSize != 4096 {
		t.Fatalf("block size: %d", blockSize)
	}
}
|||
|
|||
func testIntSyncCache(t *testing.T) { |
|||
conn, _ := setupIntegrationTarget(t) |
|||
|
|||
var cdb [16]byte |
|||
cdb[0] = iscsi.ScsiSyncCache10 |
|||
resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) |
|||
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
|||
t.Fatalf("sync cache failed: %d", resp.SCSIStatus()) |
|||
} |
|||
} |
|||
|
|||
func testIntTestUnitReady(t *testing.T) { |
|||
conn, _ := setupIntegrationTarget(t) |
|||
|
|||
var cdb [16]byte |
|||
cdb[0] = iscsi.ScsiTestUnitReady |
|||
resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) |
|||
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
|||
t.Fatal("TUR failed") |
|||
} |
|||
} |
|||
|
|||
// testConcurrentReadersWriters writes distinct fills to LBAs 0-9 and reads
// them back.
// NOTE(review): despite the name, all I/O here is sequential on a single
// connection (the body's own comment says so) — the name is kept only
// because TestIntegration's table references it; consider renaming both
// together, or adding a genuinely parallel variant.
func testConcurrentReadersWriters(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Sequential writes to different LBAs, then sequential reads to verify
	cmdSN := uint32(2)
	for lba := uint32(0); lba < 10; lba++ {
		data := bytes.Repeat([]byte{byte(lba)}, 4096)
		var wCDB [16]byte
		wCDB[0] = iscsi.ScsiWrite10
		binary.BigEndian.PutUint32(wCDB[2:6], lba)
		binary.BigEndian.PutUint16(wCDB[7:9], 1)
		resp := sendSCSICmd(t, conn, wCDB, cmdSN, false, true, data, 4096)
		if resp.SCSIStatus() != iscsi.SCSIStatusGood {
			t.Fatalf("write LBA %d failed", lba)
		}
		cmdSN++
	}

	// Read back and verify
	for lba := uint32(0); lba < 10; lba++ {
		var rCDB [16]byte
		rCDB[0] = iscsi.ScsiRead10
		binary.BigEndian.PutUint32(rCDB[2:6], lba)
		binary.BigEndian.PutUint16(rCDB[7:9], 1)
		resp := sendSCSICmd(t, conn, rCDB, cmdSN, true, false, nil, 4096)
		if resp.Opcode() != iscsi.OpSCSIDataIn {
			t.Fatalf("LBA %d: expected Data-In", lba)
		}
		expected := bytes.Repeat([]byte{byte(lba)}, 4096)
		if !bytes.Equal(resp.DataSegment, expected) {
			t.Fatalf("LBA %d: data mismatch", lba)
		}
		cmdSN++
	}
}
|||
|
|||
// testWriteAtBoundary writes the last valid block (LBA 1023 on the
// 1024-block volume) expecting success, then one block past the end
// expecting a non-GOOD status.
func testWriteAtBoundary(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Write at last valid LBA (1023)
	data := bytes.Repeat([]byte{0xFF}, 4096)
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 1023)
	binary.BigEndian.PutUint16(wCDB[7:9], 1)
	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, data, 4096)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatal("boundary write failed")
	}

	// Write past end — should fail
	var wCDB2 [16]byte
	wCDB2[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB2[2:6], 1024) // out of bounds
	binary.BigEndian.PutUint16(wCDB2[7:9], 1)
	resp2 := sendSCSICmd(t, conn, wCDB2, 3, false, true, data, 4096)
	if resp2.SCSIStatus() == iscsi.SCSIStatusGood {
		t.Fatal("OOB write should fail")
	}
}
|||
|
|||
func testUnmapIntegration(t *testing.T) { |
|||
conn, _ := setupIntegrationTarget(t) |
|||
|
|||
// Write data at LBA 5
|
|||
writeData := bytes.Repeat([]byte{0xCC}, 4096) |
|||
var wCDB [16]byte |
|||
wCDB[0] = iscsi.ScsiWrite10 |
|||
binary.BigEndian.PutUint32(wCDB[2:6], 5) |
|||
binary.BigEndian.PutUint16(wCDB[7:9], 1) |
|||
sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096) |
|||
|
|||
// UNMAP LBA 5
|
|||
unmapData := make([]byte, 24) |
|||
binary.BigEndian.PutUint16(unmapData[0:2], 22) |
|||
binary.BigEndian.PutUint16(unmapData[2:4], 16) |
|||
binary.BigEndian.PutUint64(unmapData[8:16], 5) |
|||
binary.BigEndian.PutUint32(unmapData[16:20], 1) |
|||
|
|||
var uCDB [16]byte |
|||
uCDB[0] = iscsi.ScsiUnmap |
|||
resp := sendSCSICmd(t, conn, uCDB, 3, false, true, unmapData, uint32(len(unmapData))) |
|||
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
|||
t.Fatalf("unmap failed: %d", resp.SCSIStatus()) |
|||
} |
|||
|
|||
// Read back — should be zeros
|
|||
var rCDB [16]byte |
|||
rCDB[0] = iscsi.ScsiRead10 |
|||
binary.BigEndian.PutUint32(rCDB[2:6], 5) |
|||
binary.BigEndian.PutUint16(rCDB[7:9], 1) |
|||
resp2 := sendSCSICmd(t, conn, rCDB, 4, true, false, nil, 4096) |
|||
if resp2.Opcode() != iscsi.OpSCSIDataIn { |
|||
t.Fatal("expected Data-In") |
|||
} |
|||
zeros := make([]byte, 4096) |
|||
if !bytes.Equal(resp2.DataSegment, zeros) { |
|||
t.Fatal("unmapped block should return zeros") |
|||
} |
|||
} |
|||
@ -0,0 +1,396 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"errors" |
|||
"strconv" |
|||
) |
|||
|
|||
// Login status classes (RFC 7143, Section 11.13.5)
|
|||
const (
	LoginStatusSuccess      uint8 = 0x00
	LoginStatusRedirect     uint8 = 0x01
	LoginStatusInitiatorErr uint8 = 0x02
	LoginStatusTargetErr    uint8 = 0x03
)

// Login status details.
// NOTE: detail codes are scoped to their status class, so the same numeric
// value repeats across classes (e.g. 0x01 is AuthFailure under
// InitiatorErr but ServiceUnavail under TargetErr) — always pair a detail
// constant with its matching class constant.
const (
	LoginDetailSuccess           uint8 = 0x00
	LoginDetailTargetMoved       uint8 = 0x01 // redirect: permanently moved
	LoginDetailTargetMovedTemp   uint8 = 0x02 // redirect: temporarily moved
	LoginDetailInitiatorError    uint8 = 0x00 // initiator error (miscellaneous)
	LoginDetailAuthFailure       uint8 = 0x01 // authentication failure
	LoginDetailAuthorizationFail uint8 = 0x02 // authorization failure
	LoginDetailNotFound          uint8 = 0x03 // target not found
	LoginDetailTargetRemoved     uint8 = 0x04 // target removed
	LoginDetailUnsupported       uint8 = 0x05 // unsupported version
	LoginDetailTooManyConns      uint8 = 0x06 // too many connections
	LoginDetailMissingParam      uint8 = 0x07 // missing parameter
	LoginDetailNoSessionSlot     uint8 = 0x08 // no session slot
	LoginDetailNoTCPConn         uint8 = 0x09 // no TCP connection available
	LoginDetailNoSession         uint8 = 0x0a // no existing session
	LoginDetailTargetError       uint8 = 0x00 // target error (miscellaneous)
	LoginDetailServiceUnavail    uint8 = 0x01 // service unavailable
	LoginDetailOutOfResources    uint8 = 0x02 // out of resources
)
|||
|
|||
// Sentinel errors surfaced by login negotiation; compare with errors.Is.
var (
	ErrLoginInvalidStage   = errors.New("iscsi: invalid login stage transition")
	ErrLoginMissingParam   = errors.New("iscsi: missing required login parameter")
	ErrLoginInvalidISID    = errors.New("iscsi: invalid ISID")
	ErrLoginTargetNotFound = errors.New("iscsi: target not found")
	ErrLoginSessionExists  = errors.New("iscsi: session already exists for this ISID")
	ErrLoginInvalidRequest = errors.New("iscsi: invalid login request")
)
|||
|
|||
// LoginPhase tracks the current phase of login negotiation.
|
|||
// LoginPhase tracks the current phase of login negotiation.
// Phases advance monotonically from Start to Done.
type LoginPhase int

const (
	LoginPhaseStart       LoginPhase = iota // before first PDU
	LoginPhaseSecurity                      // CSG=SecurityNeg
	LoginPhaseOperational                   // CSG=LoginOp
	LoginPhaseDone                          // transition to FFP complete
)
|||
|
|||
// TargetConfig holds the target-side negotiation defaults.
|
|||
// TargetConfig holds the target-side negotiation defaults.
// Numeric and boolean fields mirror the iSCSI text keys of the same name;
// see DefaultTargetConfig for the values used when none are supplied.
type TargetConfig struct {
	TargetName               string // IQN this target answers to
	TargetAlias              string // human-readable alias (optional)
	MaxRecvDataSegmentLength int    // bytes per PDU the target will accept
	MaxBurstLength           int
	FirstBurstLength         int
	MaxConnections           int
	MaxOutstandingR2T        int
	DefaultTime2Wait         int
	DefaultTime2Retain       int
	DataPDUInOrder           bool
	DataSequenceInOrder      bool
	InitialR2T               bool
	ImmediateData            bool
	ErrorRecoveryLevel       int
}
|||
|
|||
// DefaultTargetConfig returns sensible defaults for a target.
|
|||
// DefaultTargetConfig returns sensible defaults for a target.
// TargetName/TargetAlias are left empty — callers set them before use.
func DefaultTargetConfig() TargetConfig {
	return TargetConfig{
		MaxRecvDataSegmentLength: 262144, // 256KB
		MaxBurstLength:           262144,
		FirstBurstLength:         65536,
		MaxConnections:           1,
		MaxOutstandingR2T:        1,
		DefaultTime2Wait:         2,
		DefaultTime2Retain:       0,
		DataPDUInOrder:           true,
		DataSequenceInOrder:      true,
		InitialR2T:               true,
		ImmediateData:            true,
		ErrorRecoveryLevel:       0,
	}
}
|||
|
|||
// LoginNegotiator handles the target side of login negotiation.
|
|||
// LoginNegotiator handles the target side of login negotiation.
// It tracks the current phase and accumulates negotiated values and
// initiator identity across one or more login request PDUs.
type LoginNegotiator struct {
	config   TargetConfig
	phase    LoginPhase
	isid     [6]byte
	tsih     uint16
	targetOK bool // target name validated

	// Negotiated values (updated during negotiation)
	NegMaxRecvDataSegLen int
	NegMaxBurstLength    int
	NegFirstBurstLength  int
	NegInitialR2T        bool
	NegImmediateData     bool

	// Initiator/target info captured during login
	InitiatorName string
	TargetName    string
	SessionType   string // "Normal" or "Discovery"
}
|||
|
|||
// NewLoginNegotiator creates a negotiator for a new login sequence.
|
|||
// NewLoginNegotiator creates a negotiator for a new login sequence.
// Negotiated values start at the target's configured defaults and may be
// lowered during negotiation.
func NewLoginNegotiator(config TargetConfig) *LoginNegotiator {
	return &LoginNegotiator{
		config:               config,
		phase:                LoginPhaseStart,
		NegMaxRecvDataSegLen: config.MaxRecvDataSegmentLength,
		NegMaxBurstLength:    config.MaxBurstLength,
		NegFirstBurstLength:  config.FirstBurstLength,
		NegInitialR2T:        config.InitialR2T,
		NegImmediateData:     config.ImmediateData,
	}
}
|||
|
|||
// HandleLoginPDU processes one login request PDU and returns the response PDU.
// It manages stage transitions and parameter negotiation.
//
// The response always echoes the request's task tag and ISID. On any
// protocol error a reject status is set via setLoginReject and the PDU is
// returned immediately; on success the CSG/NSG, transit bit, TSIH, and any
// response parameters are filled in. resolver may be nil for Discovery
// sessions; for Normal sessions a nil resolver rejects every target name.
func (ln *LoginNegotiator) HandleLoginPDU(req *PDU, resolver TargetResolver) *PDU {
	resp := &PDU{}
	resp.SetOpcode(OpLoginResp)
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetISID(req.ISID())
	resp.SetTSIH(req.TSIH())

	// Validate opcode
	if req.Opcode() != OpLoginReq {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	csg := req.LoginCSG()
	nsg := req.LoginNSG()
	transit := req.LoginTransit()

	// Parse text parameters from data segment
	params, err := ParseParams(req.DataSegment)
	if err != nil {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Process based on current stage
	respParams := NewParams()

	switch csg {
	case StageSecurityNeg:
		// Security stage is only legal from the start or when re-entering it
		// (an initiator may send several non-transit security PDUs).
		if ln.phase != LoginPhaseStart && ln.phase != LoginPhaseSecurity {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
			return resp
		}
		ln.phase = LoginPhaseSecurity

		// Capture initiator name
		if name, ok := params.Get("InitiatorName"); ok {
			ln.InitiatorName = name
		} else if ln.InitiatorName == "" {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// Session type
		if st, ok := params.Get("SessionType"); ok {
			ln.SessionType = st
		}
		if ln.SessionType == "" {
			ln.SessionType = "Normal"
		}

		// Target name (required for Normal sessions)
		if tn, ok := params.Get("TargetName"); ok {
			if ln.SessionType == "Normal" {
				if resolver == nil || !resolver.HasTarget(tn) {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailNotFound)
					return resp
				}
			}
			ln.TargetName = tn
			ln.targetOK = true
		} else if ln.SessionType == "Normal" && !ln.targetOK {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// ISID
		ln.isid = req.ISID()

		// We don't implement CHAP — declare AuthMethod=None
		respParams.Set("AuthMethod", "None")

		if transit {
			if nsg == StageLoginOp {
				ln.phase = LoginPhaseOperational
			} else if nsg == StageFullFeature {
				ln.phase = LoginPhaseDone
			}
		}

	case StageLoginOp:
		if ln.phase != LoginPhaseOperational && ln.phase != LoginPhaseSecurity {
			// Allow direct jump to LoginOp if security was skipped
			if ln.phase == LoginPhaseStart {
				// Need InitiatorName at minimum
				if name, ok := params.Get("InitiatorName"); ok {
					ln.InitiatorName = name
				} else if ln.InitiatorName == "" {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
					return resp
				}
			} else {
				setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
				return resp
			}
		}
		ln.phase = LoginPhaseOperational

		// Negotiate operational parameters
		ln.negotiateParams(params, respParams)

		if transit && nsg == StageFullFeature {
			ln.phase = LoginPhaseDone
		}

	default:
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Build response
	resp.SetLoginStages(csg, nsg)
	if transit {
		resp.SetLoginTransit(true)
	}
	resp.SetLoginStatus(LoginStatusSuccess, LoginDetailSuccess)

	// Assign TSIH on first successful login
	if ln.tsih == 0 {
		ln.tsih = 1 // simplified: single session
	}
	resp.SetTSIH(ln.tsih)

	// Encode response params
	if respParams.Len() > 0 {
		resp.DataSegment = respParams.Encode()
	}

	return resp
}
|||
|
|||
// Done returns true if login negotiation is complete.
|
|||
func (ln *LoginNegotiator) Done() bool { |
|||
return ln.phase == LoginPhaseDone |
|||
} |
|||
|
|||
// Phase returns the current login phase.
|
|||
func (ln *LoginNegotiator) Phase() LoginPhase { |
|||
return ln.phase |
|||
} |
|||
|
|||
// negotiateParams processes operational parameter negotiation.
//
// Each key offered by the initiator is either negotiated (numeric min rule,
// boolean AND/OR rule), answered with our fixed declared value, or answered
// with NotUnderstood. Responses are accumulated into resp in offer order.
// Invalid numeric/boolean offers are silently dropped (no response key) —
// NOTE(review): confirm dropping rather than rejecting is intended.
func (ln *LoginNegotiator) negotiateParams(req *Params, resp *Params) {
	req.Each(func(key, value string) {
		switch key {
		case "MaxRecvDataSegmentLength":
			if v, err := NegotiateNumber(value, ln.config.MaxRecvDataSegmentLength, 512, 16777215); err == nil {
				ln.NegMaxRecvDataSegLen = v
				// NOTE(review): unlike the other numeric keys below, the
				// response echoes our configured value rather than the
				// negotiated v. MaxRecvDataSegmentLength is declarative in
				// RFC 7143 (each side declares its own receive limit), so
				// echoing config is plausible — but then storing
				// min(offer, ours) in NegMaxRecvDataSegLen may undershoot;
				// confirm which value outgoing Data-In sizing should use.
				resp.Set(key, strconv.Itoa(ln.config.MaxRecvDataSegmentLength))
			}
		case "MaxBurstLength":
			if v, err := NegotiateNumber(value, ln.config.MaxBurstLength, 512, 16777215); err == nil {
				ln.NegMaxBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "FirstBurstLength":
			if v, err := NegotiateNumber(value, ln.config.FirstBurstLength, 512, 16777215); err == nil {
				ln.NegFirstBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "InitialR2T":
			if v, err := NegotiateBool(value, ln.config.InitialR2T); err == nil {
				// InitialR2T uses OR semantics: result is Yes if either side says Yes
				if value == "Yes" || ln.config.InitialR2T {
					ln.NegInitialR2T = true
				} else {
					// Both sides said No; v is the AND result (also false here).
					ln.NegInitialR2T = v
				}
				resp.Set(key, BoolStr(ln.NegInitialR2T))
			}
		case "ImmediateData":
			// AND semantics: enabled only if both sides agree.
			if v, err := NegotiateBool(value, ln.config.ImmediateData); err == nil {
				ln.NegImmediateData = v
				resp.Set(key, BoolStr(v))
			}
		case "MaxConnections":
			// Declared, not negotiated: we answer with our fixed values.
			resp.Set(key, strconv.Itoa(ln.config.MaxConnections))
		case "DataPDUInOrder":
			resp.Set(key, BoolStr(ln.config.DataPDUInOrder))
		case "DataSequenceInOrder":
			resp.Set(key, BoolStr(ln.config.DataSequenceInOrder))
		case "DefaultTime2Wait":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Wait))
		case "DefaultTime2Retain":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Retain))
		case "MaxOutstandingR2T":
			resp.Set(key, strconv.Itoa(ln.config.MaxOutstandingR2T))
		case "ErrorRecoveryLevel":
			resp.Set(key, strconv.Itoa(ln.config.ErrorRecoveryLevel))
		case "HeaderDigest":
			// Digests are not implemented; always answer None regardless of offer.
			resp.Set(key, "None")
		case "DataDigest":
			resp.Set(key, "None")
		case "TargetName", "InitiatorName", "SessionType", "AuthMethod":
			// Already handled or declarative — skip
		case "TargetAlias":
			// Informational from initiator — skip
		default:
			// Unknown keys: respond with NotUnderstood
			resp.Set(key, "NotUnderstood")
		}
	})

	// Always declare our TargetAlias if configured
	if ln.config.TargetAlias != "" {
		resp.Set("TargetAlias", ln.config.TargetAlias)
	}
}
|||
|
|||
// TargetResolver allows the login state machine to check if a target exists.
// A nil TargetResolver passed to HandleLoginPDU causes every Normal-session
// target name to be rejected with "not found".
type TargetResolver interface {
	// HasTarget reports whether a target with the given IQN is served here.
	HasTarget(name string) bool
}
|||
|
|||
func setLoginReject(resp *PDU, class, detail uint8) { |
|||
resp.SetLoginStatus(class, detail) |
|||
resp.SetLoginTransit(false) |
|||
} |
|||
|
|||
// LoginResult contains the outcome of a completed login negotiation.
type LoginResult struct {
	InitiatorName string // initiator IQN captured during login
	TargetName string // target IQN the session is bound to (empty for Discovery)
	SessionType string // "Normal" or "Discovery"
	ISID [6]byte // initiator session ID from the login PDU
	TSIH uint16 // target session identifying handle assigned at login
	// Operational values agreed during negotiation.
	MaxRecvDataSegLen int
	MaxBurstLength int
	FirstBurstLength int
	InitialR2T bool
	ImmediateData bool
}
|||
|
|||
// Result returns the negotiation outcome. Only valid after Done() returns true.
|
|||
func (ln *LoginNegotiator) Result() LoginResult { |
|||
return LoginResult{ |
|||
InitiatorName: ln.InitiatorName, |
|||
TargetName: ln.TargetName, |
|||
SessionType: ln.SessionType, |
|||
ISID: ln.isid, |
|||
TSIH: ln.tsih, |
|||
MaxRecvDataSegLen: ln.NegMaxRecvDataSegLen, |
|||
MaxBurstLength: ln.NegMaxBurstLength, |
|||
FirstBurstLength: ln.NegFirstBurstLength, |
|||
InitialR2T: ln.NegInitialR2T, |
|||
ImmediateData: ln.NegImmediateData, |
|||
} |
|||
} |
|||
|
|||
// BuildRedirectResponse creates a login response that redirects the initiator
|
|||
// to a different target address.
|
|||
func BuildRedirectResponse(req *PDU, addr string, permanent bool) *PDU { |
|||
resp := &PDU{} |
|||
resp.SetOpcode(OpLoginResp) |
|||
resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) |
|||
resp.SetISID(req.ISID()) |
|||
|
|||
detail := LoginDetailTargetMovedTemp |
|||
if permanent { |
|||
detail = LoginDetailTargetMoved |
|||
} |
|||
resp.SetLoginStatus(LoginStatusRedirect, detail) |
|||
|
|||
params := NewParams() |
|||
params.Set("TargetAddress", addr) |
|||
resp.DataSegment = params.Encode() |
|||
|
|||
return resp |
|||
} |
|||
@ -0,0 +1,444 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"testing" |
|||
) |
|||
|
|||
// mockResolver implements TargetResolver for tests.
type mockResolver struct {
	targets map[string]bool
}

// HasTarget reports whether name was registered via newResolver.
func (m *mockResolver) HasTarget(name string) bool {
	known := m.targets[name]
	return known
}

// newResolver builds a mockResolver that knows exactly the given target names.
func newResolver(names ...string) *mockResolver {
	known := make(map[string]bool, len(names))
	for _, name := range names {
		known[name] = true
	}
	return &mockResolver{targets: known}
}
|||
|
|||
// Shared IQN fixtures used throughout the login tests.
const testTargetName = "iqn.2024.com.seaweedfs:vol1"
const testInitiatorName = "iqn.2024.com.test:client1"
|||
|
|||
// TestLogin dispatches the login negotiation scenarios as named subtests.
func TestLogin(t *testing.T) {
	tests := []struct {
		name string
		run func(t *testing.T)
	}{
		{"single_pdu_security_to_ffp", testSinglePDUSecurityToFFP},
		{"two_phase_login", testTwoPhaseLogin},
		{"security_to_loginop_to_ffp", testSecurityToLoginOpToFFP},
		{"missing_initiator_name", testMissingInitiatorName},
		{"target_not_found", testTargetNotFound},
		{"discovery_session_no_target", testDiscoverySessionNoTarget},
		{"wrong_opcode", testWrongOpcode},
		{"operational_negotiation", testOperationalNegotiation},
		{"redirect_permanent", testRedirectPermanent},
		{"redirect_temporary", testRedirectTemporary},
		{"login_result", testLoginResult},
		{"unknown_key_not_understood", testUnknownKeyNotUnderstood},
		{"header_data_digest_none", testHeaderDataDigestNone},
		{"duplicate_login_pdu", testDuplicateLoginPDU},
		{"invalid_csg", testInvalidCSG},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
func makeLoginReq(csg, nsg uint8, transit bool, params *Params) *PDU { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpLoginReq) |
|||
p.SetLoginStages(csg, nsg) |
|||
if transit { |
|||
p.SetLoginTransit(true) |
|||
} |
|||
isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} |
|||
p.SetISID(isid) |
|||
if params != nil { |
|||
p.DataSegment = params.Encode() |
|||
} |
|||
return p |
|||
} |
|||
|
|||
// testSinglePDUSecurityToFFP covers the fast path: one security-stage PDU
// that transits straight to full-feature phase.
func testSinglePDUSecurityToFFP(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, resolver)

	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status class: %d", resp.LoginStatusClass())
	}
	if !resp.LoginTransit() {
		t.Fatal("transit should be set")
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
	if resp.TSIH() == 0 {
		t.Fatal("TSIH should be assigned")
	}

	// Verify AuthMethod=None in response
	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, ok := rp.Get("AuthMethod"); !ok || v != "None" {
		t.Fatalf("AuthMethod: %q, %v", v, ok)
	}
}
|||
|
|||
// testTwoPhaseLogin walks security -> operational -> full-feature in two
// PDUs and checks Done() only flips after the second transit.
func testTwoPhaseLogin(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// Phase 1: Security -> LoginOp
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")

	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	resp := ln.HandleLoginPDU(req, resolver)

	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("phase 1 status: %d", resp.LoginStatusClass())
	}
	if ln.Done() {
		t.Fatal("should not be done yet")
	}

	// Phase 2: LoginOp -> FFP
	params2 := NewParams()
	params2.Set("MaxRecvDataSegmentLength", "65536")

	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp2 := ln.HandleLoginPDU(req2, resolver)

	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("phase 2 status: %d", resp2.LoginStatusClass())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
}
|||
|
|||
// testSecurityToLoginOpToFFP uses three PDUs: a non-transit security PDU,
// then explicit transits through operational to full-feature.
func testSecurityToLoginOpToFFP(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// Security (no transit)
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)

	req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params)
	resp := ln.HandleLoginPDU(req, resolver)
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp.LoginStatusClass())
	}

	// Security -> LoginOp (transit)
	req2 := makeLoginReq(StageSecurityNeg, StageLoginOp, true, NewParams())
	resp2 := ln.HandleLoginPDU(req2, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp2.LoginStatusClass())
	}

	// LoginOp -> FFP (transit)
	req3 := makeLoginReq(StageLoginOp, StageFullFeature, true, NewParams())
	resp3 := ln.HandleLoginPDU(req3, resolver)
	if resp3.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp3.LoginStatusClass())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
}
|||
|
|||
// testMissingInitiatorName checks a security PDU without InitiatorName is
// rejected with the missing-parameter status detail.
func testMissingInitiatorName(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// No InitiatorName
	params := NewParams()
	params.Set("TargetName", testTargetName)

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, resolver)

	if resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected initiator error, got %d", resp.LoginStatusClass())
	}
	if resp.LoginStatusDetail() != LoginDetailMissingParam {
		t.Fatalf("expected missing param, got %d", resp.LoginStatusDetail())
	}
}
|||
|
|||
// testTargetNotFound checks an unknown TargetName yields a not-found reject.
func testTargetNotFound(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver() // empty

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", "iqn.2024.com.seaweedfs:nonexistent")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, resolver)

	if resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected initiator error, got %d", resp.LoginStatusClass())
	}
	if resp.LoginStatusDetail() != LoginDetailNotFound {
		t.Fatalf("expected not found, got %d", resp.LoginStatusDetail())
	}
}
|||
|
|||
// testDiscoverySessionNoTarget checks a Discovery session logs in without a
// TargetName and without any resolver.
func testDiscoverySessionNoTarget(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("SessionType", "Discovery")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	resp := ln.HandleLoginPDU(req, nil) // no resolver needed

	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
	if !ln.Done() {
		t.Fatal("should be done")
	}
	if ln.SessionType != "Discovery" {
		t.Fatalf("session type: %q", ln.SessionType)
	}
}
|||
|
|||
// testWrongOpcode checks a non-login PDU fed to the negotiator is rejected.
func testWrongOpcode(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())

	req := &PDU{}
	req.SetOpcode(OpSCSICmd) // wrong
	resp := ln.HandleLoginPDU(req, nil)

	if resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected error, got %d", resp.LoginStatusClass())
	}
}
|||
|
|||
// testOperationalNegotiation drives a full operational-stage exchange and
// verifies the response keys (digests forced to None, InitialR2T OR rule,
// ImmediateData AND rule, TargetAlias declared) and the stored Neg* state.
func testOperationalNegotiation(t *testing.T) {
	config := DefaultTargetConfig()
	config.TargetAlias = "test-target"
	ln := NewLoginNegotiator(config)
	resolver := newResolver(testTargetName)

	// First: security phase
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	// Second: operational negotiation
	params2 := NewParams()
	params2.Set("MaxRecvDataSegmentLength", "65536")
	params2.Set("MaxBurstLength", "131072")
	params2.Set("FirstBurstLength", "32768")
	params2.Set("InitialR2T", "Yes")
	params2.Set("ImmediateData", "No")
	params2.Set("HeaderDigest", "CRC32C,None")
	params2.Set("DataDigest", "CRC32C,None")

	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp.LoginStatusClass())
	}

	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}

	// Verify negotiated values
	if v, ok := rp.Get("HeaderDigest"); !ok || v != "None" {
		t.Fatalf("HeaderDigest: %q", v)
	}
	if v, ok := rp.Get("DataDigest"); !ok || v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
	if v, ok := rp.Get("InitialR2T"); !ok || v != "Yes" {
		t.Fatalf("InitialR2T: %q", v)
	}
	if v, ok := rp.Get("ImmediateData"); !ok || v != "No" {
		t.Fatalf("ImmediateData: %q", v)
	}
	if v, ok := rp.Get("TargetAlias"); !ok || v != "test-target" {
		t.Fatalf("TargetAlias: %q", v)
	}

	// Check negotiated state
	if ln.NegInitialR2T != true {
		t.Fatal("NegInitialR2T should be true")
	}
	if ln.NegImmediateData != false {
		t.Fatal("NegImmediateData should be false")
	}
}
|||
|
|||
// testRedirectPermanent checks a permanent redirect carries the moved
// status detail and the TargetAddress parameter.
func testRedirectPermanent(t *testing.T) {
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams())
	resp := BuildRedirectResponse(req, "10.0.0.1:3260,1", true)

	if resp.LoginStatusClass() != LoginStatusRedirect {
		t.Fatalf("class: %d", resp.LoginStatusClass())
	}
	if resp.LoginStatusDetail() != LoginDetailTargetMoved {
		t.Fatalf("detail: %d", resp.LoginStatusDetail())
	}

	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, _ := rp.Get("TargetAddress"); v != "10.0.0.1:3260,1" {
		t.Fatalf("TargetAddress: %q", v)
	}
}
|||
|
|||
// testRedirectTemporary checks the temporary-redirect status detail.
func testRedirectTemporary(t *testing.T) {
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams())
	resp := BuildRedirectResponse(req, "192.168.1.100:3260,2", false)

	if resp.LoginStatusDetail() != LoginDetailTargetMovedTemp {
		t.Fatalf("detail: %d", resp.LoginStatusDetail())
	}
}
|||
|
|||
// testLoginResult checks Result() reflects the completed negotiation.
func testLoginResult(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	ln.HandleLoginPDU(req, resolver)

	result := ln.Result()
	if result.InitiatorName != testInitiatorName {
		t.Fatalf("initiator: %q", result.InitiatorName)
	}
	if result.SessionType != "Normal" {
		t.Fatalf("session type: %q", result.SessionType)
	}
	if result.TSIH == 0 {
		t.Fatal("TSIH should be assigned")
	}
}
|||
|
|||
// testUnknownKeyNotUnderstood checks unrecognized keys are answered with
// the literal "NotUnderstood" value.
func testUnknownKeyNotUnderstood(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// Security phase first
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	// Op phase with unknown key
	params2 := NewParams()
	params2.Set("X-CustomKey", "whatever")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, ok := rp.Get("X-CustomKey"); !ok || v != "NotUnderstood" {
		t.Fatalf("expected NotUnderstood, got %q, %v", v, ok)
	}
}
|||
|
|||
// testHeaderDataDigestNone checks digest offers are always answered None,
// even when the initiator only offers CRC32C.
func testHeaderDataDigestNone(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	params2 := NewParams()
	params2.Set("HeaderDigest", "CRC32C")
	params2.Set("DataDigest", "CRC32C")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	rp, _ := ParseParams(resp.DataSegment)
	if v, _ := rp.Get("HeaderDigest"); v != "None" {
		t.Fatalf("HeaderDigest: %q (we always negotiate None)", v)
	}
	if v, _ := rp.Get("DataDigest"); v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
}
|||
|
|||
// testDuplicateLoginPDU checks re-sending the same non-transit security PDU
// is tolerated (the security stage may be re-entered).
func testDuplicateLoginPDU(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)

	// Send same security PDU twice (no transit)
	req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params)
	resp1 := ln.HandleLoginPDU(req, resolver)
	if resp1.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("first: %d", resp1.LoginStatusClass())
	}

	// Same stage again should still work
	resp2 := ln.HandleLoginPDU(req, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("second: %d", resp2.LoginStatusClass())
	}
}
|||
|
|||
// testInvalidCSG checks that FullFeature is rejected as a current stage.
func testInvalidCSG(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())

	// CSG=FullFeature (3) is invalid as a current stage
	req := &PDU{}
	req.SetOpcode(OpLoginReq)
	req.SetLoginStages(StageFullFeature, StageFullFeature)
	req.SetLoginTransit(true)
	isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}
	req.SetISID(isid)

	resp := ln.HandleLoginPDU(req, nil)
	if resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected error for invalid CSG, got %d", resp.LoginStatusClass())
	}
}
|||
@ -0,0 +1,183 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"strconv" |
|||
"strings" |
|||
) |
|||
|
|||
// iSCSI key-value text parameters (RFC 7143, Section 6).
|
|||
// Parameters are encoded as "Key=Value\0" in the data segment.
|
|||
|
|||
// Sentinel errors returned by ParseParams for malformed text parameter data.
// ErrMalformedParam and ErrDuplicateKey are wrapped with %w (match with
// errors.Is); ErrEmptyKey is returned unwrapped.
var (
	ErrMalformedParam = errors.New("iscsi: malformed parameter (missing '=')")
	ErrEmptyKey = errors.New("iscsi: empty parameter key")
	ErrDuplicateKey = errors.New("iscsi: duplicate parameter key")
)
|||
|
|||
// Params is an ordered list of iSCSI key-value parameters.
// Order matters for negotiation, so we use a slice rather than a map.
// The zero value is an empty, usable parameter list.
type Params struct {
	items []paramItem // insertion-ordered key/value pairs
}

// paramItem is a single "Key=Value" entry within Params.
type paramItem struct {
	key string
	value string
}
|||
|
|||
// ParseParams decodes "Key=Value\0Key=Value\0..." from raw bytes.
|
|||
// Empty trailing segments (from a trailing \0) are ignored.
|
|||
// Returns ErrDuplicateKey if the same key appears more than once.
|
|||
func ParseParams(data []byte) (*Params, error) { |
|||
p := &Params{} |
|||
if len(data) == 0 { |
|||
return p, nil |
|||
} |
|||
|
|||
seen := make(map[string]bool) |
|||
s := string(data) |
|||
|
|||
// Split by null separator
|
|||
parts := strings.Split(s, "\x00") |
|||
for _, part := range parts { |
|||
if part == "" { |
|||
continue // trailing null or empty
|
|||
} |
|||
idx := strings.IndexByte(part, '=') |
|||
if idx < 0 { |
|||
return nil, fmt.Errorf("%w: %q", ErrMalformedParam, part) |
|||
} |
|||
key := part[:idx] |
|||
if key == "" { |
|||
return nil, ErrEmptyKey |
|||
} |
|||
value := part[idx+1:] |
|||
|
|||
if seen[key] { |
|||
return nil, fmt.Errorf("%w: %q", ErrDuplicateKey, key) |
|||
} |
|||
seen[key] = true |
|||
p.items = append(p.items, paramItem{key: key, value: value}) |
|||
} |
|||
|
|||
return p, nil |
|||
} |
|||
|
|||
// Encode serializes parameters to "Key=Value\0" format.
|
|||
func (p *Params) Encode() []byte { |
|||
if len(p.items) == 0 { |
|||
return nil |
|||
} |
|||
var b strings.Builder |
|||
for _, item := range p.items { |
|||
b.WriteString(item.key) |
|||
b.WriteByte('=') |
|||
b.WriteString(item.value) |
|||
b.WriteByte(0) |
|||
} |
|||
return []byte(b.String()) |
|||
} |
|||
|
|||
// Get returns the value for a key, or ("", false) if not present.
|
|||
func (p *Params) Get(key string) (string, bool) { |
|||
for _, item := range p.items { |
|||
if item.key == key { |
|||
return item.value, true |
|||
} |
|||
} |
|||
return "", false |
|||
} |
|||
|
|||
// Set adds or replaces a parameter. If the key already exists, its value
|
|||
// is updated in place. Otherwise, the key is appended.
|
|||
func (p *Params) Set(key, value string) { |
|||
for i, item := range p.items { |
|||
if item.key == key { |
|||
p.items[i].value = value |
|||
return |
|||
} |
|||
} |
|||
p.items = append(p.items, paramItem{key: key, value: value}) |
|||
} |
|||
|
|||
// Del removes a key if present.
|
|||
func (p *Params) Del(key string) { |
|||
for i, item := range p.items { |
|||
if item.key == key { |
|||
p.items = append(p.items[:i], p.items[i+1:]...) |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Keys returns all keys in order.
|
|||
func (p *Params) Keys() []string { |
|||
keys := make([]string, len(p.items)) |
|||
for i, item := range p.items { |
|||
keys[i] = item.key |
|||
} |
|||
return keys |
|||
} |
|||
|
|||
// Len returns the number of parameters currently held.
func (p *Params) Len() int { return len(p.items) }
|||
|
|||
// Each iterates over all key-value pairs in order.
|
|||
func (p *Params) Each(fn func(key, value string)) { |
|||
for _, item := range p.items { |
|||
fn(item.key, item.value) |
|||
} |
|||
} |
|||
|
|||
// --- Negotiation helpers ---
|
|||
|
|||
// NegotiateNumber applies the iSCSI numeric negotiation rule for
// minimum-result keys: the offered value is first clamped to the legal
// [min, max] range for the key, then the result is the smaller of the
// clamped offer and our own value. Returns an error if offered is not a
// decimal integer.
func NegotiateNumber(offered string, ours int, min, max int) (int, error) {
	n, err := strconv.Atoi(offered)
	if err != nil {
		return 0, fmt.Errorf("iscsi: invalid numeric value %q: %w", offered, err)
	}
	// Clamp the offer into the key's legal range.
	switch {
	case n < min:
		n = min
	case n > max:
		n = max
	}
	// Take the smaller of the clamped offer and our value.
	if n > ours {
		n = ours
	}
	return n, nil
}
|||
|
|||
// NegotiateBool applies boolean negotiation (AND semantics per RFC 7143):
// the result is true only when the initiator offered "Yes" and our side is
// also true. Any offer other than "Yes"/"No" is an error.
func NegotiateBool(offered string, ours bool) (bool, error) {
	if offered == "Yes" {
		return ours, nil
	}
	if offered == "No" {
		return false, nil
	}
	return false, fmt.Errorf("iscsi: invalid boolean value %q", offered)
}
|||
|
|||
// BoolStr renders a boolean in iSCSI text form: "Yes" for true, "No" for false.
func BoolStr(v bool) string {
	s := "No"
	if v {
		s = "Yes"
	}
	return s
}
|||
|
|||
// NewParams creates a new empty Params.
|
|||
func NewParams() *Params { |
|||
return &Params{} |
|||
} |
|||
@ -0,0 +1,357 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"bytes" |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
// TestParams dispatches the Params parsing/encoding/negotiation scenarios
// as named subtests.
func TestParams(t *testing.T) {
	tests := []struct {
		name string
		run func(t *testing.T)
	}{
		{"parse_single", testParseSingle},
		{"parse_multiple", testParseMultiple},
		{"parse_empty", testParseEmpty},
		{"parse_trailing_null", testParseTrailingNull},
		{"parse_malformed_no_equals", testParseMalformedNoEquals},
		{"parse_empty_key", testParseEmptyKey},
		{"parse_empty_value", testParseEmptyValue},
		{"parse_duplicate_key", testParseDuplicateKey},
		{"roundtrip", testParamsRoundtrip},
		{"encode_empty", testEncodeEmpty},
		{"set_new_key", testSetNewKey},
		{"set_existing_key", testSetExistingKey},
		{"del_key", testDelKey},
		{"del_nonexistent", testDelNonexistent},
		{"keys_order", testKeysOrder},
		{"each", testEach},
		{"negotiate_number_min", testNegotiateNumberMin},
		{"negotiate_number_clamp", testNegotiateNumberClamp},
		{"negotiate_number_invalid", testNegotiateNumberInvalid},
		{"negotiate_bool_and", testNegotiateBoolAnd},
		{"negotiate_bool_invalid", testNegotiateBoolInvalid},
		{"bool_str", testBoolStr},
		{"value_with_equals", testValueWithEquals},
		{"parse_binary_data_value", testParseBinaryDataValue},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// testParseSingle parses one NUL-terminated pair and reads it back.
func testParseSingle(t *testing.T) {
	data := []byte("TargetName=iqn.2024.com.seaweedfs:vol1\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 1 {
		t.Fatalf("expected 1 param, got %d", p.Len())
	}
	v, ok := p.Get("TargetName")
	if !ok || v != "iqn.2024.com.seaweedfs:vol1" {
		t.Fatalf("got %q, %v", v, ok)
	}
}
|||
|
|||
// testParseMultiple checks multiple pairs parse and preserve offer order.
func testParseMultiple(t *testing.T) {
	data := []byte("InitiatorName=iqn.2024.com.test\x00TargetName=iqn.2024.com.seaweedfs:vol1\x00SessionType=Normal\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 3 {
		t.Fatalf("expected 3 params, got %d", p.Len())
	}
	keys := p.Keys()
	if keys[0] != "InitiatorName" || keys[1] != "TargetName" || keys[2] != "SessionType" {
		t.Fatalf("wrong order: %v", keys)
	}
}
|||
|
|||
// testParseEmpty checks nil and zero-length input both yield an empty list.
func testParseEmpty(t *testing.T) {
	p, err := ParseParams(nil)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 0 {
		t.Fatalf("expected 0, got %d", p.Len())
	}

	p2, err := ParseParams([]byte{})
	if err != nil {
		t.Fatal(err)
	}
	if p2.Len() != 0 {
		t.Fatalf("expected 0, got %d", p2.Len())
	}
}
|||
|
|||
func testParseTrailingNull(t *testing.T) { |
|||
// Multiple trailing nulls should not cause errors
|
|||
data := []byte("Key=Value\x00\x00\x00") |
|||
p, err := ParseParams(data) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if p.Len() != 1 { |
|||
t.Fatalf("expected 1, got %d", p.Len()) |
|||
} |
|||
} |
|||
|
|||
func testParseMalformedNoEquals(t *testing.T) { |
|||
data := []byte("KeyWithoutValue\x00") |
|||
_, err := ParseParams(data) |
|||
if err == nil { |
|||
t.Fatal("expected error") |
|||
} |
|||
if !strings.Contains(err.Error(), "malformed") { |
|||
t.Fatalf("unexpected error: %v", err) |
|||
} |
|||
} |
|||
|
|||
func testParseEmptyKey(t *testing.T) { |
|||
data := []byte("=Value\x00") |
|||
_, err := ParseParams(data) |
|||
if err != ErrEmptyKey { |
|||
t.Fatalf("expected ErrEmptyKey, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testParseEmptyValue(t *testing.T) { |
|||
// Empty value is valid in iSCSI (e.g., reject with empty value)
|
|||
data := []byte("Key=\x00") |
|||
p, err := ParseParams(data) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
v, ok := p.Get("Key") |
|||
if !ok || v != "" { |
|||
t.Fatalf("got %q, %v", v, ok) |
|||
} |
|||
} |
|||
|
|||
func testParseDuplicateKey(t *testing.T) { |
|||
data := []byte("Key=Value1\x00Key=Value2\x00") |
|||
_, err := ParseParams(data) |
|||
if err == nil { |
|||
t.Fatal("expected error") |
|||
} |
|||
if !strings.Contains(err.Error(), "duplicate") { |
|||
t.Fatalf("unexpected error: %v", err) |
|||
} |
|||
} |
|||
|
|||
func testParamsRoundtrip(t *testing.T) { |
|||
original := []byte("InitiatorName=iqn.2024.com.test\x00MaxRecvDataSegmentLength=65536\x00") |
|||
p, err := ParseParams(original) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
encoded := p.Encode() |
|||
if !bytes.Equal(encoded, original) { |
|||
t.Fatalf("roundtrip mismatch:\n got: %q\n want: %q", encoded, original) |
|||
} |
|||
} |
|||
|
|||
func testEncodeEmpty(t *testing.T) { |
|||
p := NewParams() |
|||
if encoded := p.Encode(); encoded != nil { |
|||
t.Fatalf("expected nil, got %q", encoded) |
|||
} |
|||
} |
|||
|
|||
func testSetNewKey(t *testing.T) { |
|||
p := NewParams() |
|||
p.Set("Key1", "Val1") |
|||
p.Set("Key2", "Val2") |
|||
if p.Len() != 2 { |
|||
t.Fatalf("expected 2, got %d", p.Len()) |
|||
} |
|||
v, _ := p.Get("Key2") |
|||
if v != "Val2" { |
|||
t.Fatalf("got %q", v) |
|||
} |
|||
} |
|||
|
|||
func testSetExistingKey(t *testing.T) { |
|||
p := NewParams() |
|||
p.Set("Key", "Old") |
|||
p.Set("Key", "New") |
|||
if p.Len() != 1 { |
|||
t.Fatalf("expected 1, got %d", p.Len()) |
|||
} |
|||
v, _ := p.Get("Key") |
|||
if v != "New" { |
|||
t.Fatalf("got %q", v) |
|||
} |
|||
} |
|||
|
|||
func testDelKey(t *testing.T) { |
|||
p := NewParams() |
|||
p.Set("A", "1") |
|||
p.Set("B", "2") |
|||
p.Set("C", "3") |
|||
p.Del("B") |
|||
if p.Len() != 2 { |
|||
t.Fatalf("expected 2, got %d", p.Len()) |
|||
} |
|||
if _, ok := p.Get("B"); ok { |
|||
t.Fatal("B should be deleted") |
|||
} |
|||
keys := p.Keys() |
|||
if keys[0] != "A" || keys[1] != "C" { |
|||
t.Fatalf("wrong keys: %v", keys) |
|||
} |
|||
} |
|||
|
|||
func testDelNonexistent(t *testing.T) { |
|||
p := NewParams() |
|||
p.Set("A", "1") |
|||
p.Del("Z") // should not panic
|
|||
if p.Len() != 1 { |
|||
t.Fatalf("expected 1, got %d", p.Len()) |
|||
} |
|||
} |
|||
|
|||
func testKeysOrder(t *testing.T) { |
|||
data := []byte("Z=1\x00A=2\x00M=3\x00") |
|||
p, err := ParseParams(data) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
keys := p.Keys() |
|||
if keys[0] != "Z" || keys[1] != "A" || keys[2] != "M" { |
|||
t.Fatalf("order not preserved: %v", keys) |
|||
} |
|||
} |
|||
|
|||
func testEach(t *testing.T) { |
|||
p := NewParams() |
|||
p.Set("A", "1") |
|||
p.Set("B", "2") |
|||
var collected []string |
|||
p.Each(func(k, v string) { |
|||
collected = append(collected, k+"="+v) |
|||
}) |
|||
if len(collected) != 2 || collected[0] != "A=1" || collected[1] != "B=2" { |
|||
t.Fatalf("Each: %v", collected) |
|||
} |
|||
} |
|||
|
|||
func testNegotiateNumberMin(t *testing.T) { |
|||
// Both sides offer a value, result is min
|
|||
result, err := NegotiateNumber("65536", 262144, 512, 16777215) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if result != 65536 { |
|||
t.Fatalf("expected 65536, got %d", result) |
|||
} |
|||
|
|||
// Our value is smaller
|
|||
result, err = NegotiateNumber("262144", 65536, 512, 16777215) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if result != 65536 { |
|||
t.Fatalf("expected 65536, got %d", result) |
|||
} |
|||
} |
|||
|
|||
// testNegotiateNumberClamp verifies how out-of-range offers are resolved.
// NOTE(review): the expectations here are asymmetric — an offer below the
// minimum is clamped up to the minimum (512), while an offer above the
// maximum resolves to our own value (65536) rather than the maximum.
// Confirm this asymmetry is intentional in NegotiateNumber.
func testNegotiateNumberClamp(t *testing.T) {
	// Offered value below minimum
	result, err := NegotiateNumber("100", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 512 {
		t.Fatalf("expected 512, got %d", result)
	}

	// Offered value above maximum
	result, err = NegotiateNumber("99999999", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
}
|||
|
|||
func testNegotiateNumberInvalid(t *testing.T) { |
|||
_, err := NegotiateNumber("notanumber", 65536, 512, 16777215) |
|||
if err == nil { |
|||
t.Fatal("expected error") |
|||
} |
|||
} |
|||
|
|||
func testNegotiateBoolAnd(t *testing.T) { |
|||
// Both yes
|
|||
r, err := NegotiateBool("Yes", true) |
|||
if err != nil || !r { |
|||
t.Fatalf("Yes+true = %v, %v", r, err) |
|||
} |
|||
// Initiator yes, target no
|
|||
r, err = NegotiateBool("Yes", false) |
|||
if err != nil || r { |
|||
t.Fatalf("Yes+false = %v, %v", r, err) |
|||
} |
|||
// Initiator no, target yes
|
|||
r, err = NegotiateBool("No", true) |
|||
if err != nil || r { |
|||
t.Fatalf("No+true = %v, %v", r, err) |
|||
} |
|||
// Both no
|
|||
r, err = NegotiateBool("No", false) |
|||
if err != nil || r { |
|||
t.Fatalf("No+false = %v, %v", r, err) |
|||
} |
|||
} |
|||
|
|||
func testNegotiateBoolInvalid(t *testing.T) { |
|||
_, err := NegotiateBool("Maybe", true) |
|||
if err == nil { |
|||
t.Fatal("expected error") |
|||
} |
|||
} |
|||
|
|||
func testBoolStr(t *testing.T) { |
|||
if BoolStr(true) != "Yes" { |
|||
t.Fatal("true should be Yes") |
|||
} |
|||
if BoolStr(false) != "No" { |
|||
t.Fatal("false should be No") |
|||
} |
|||
} |
|||
|
|||
func testValueWithEquals(t *testing.T) { |
|||
// Value containing '=' is valid (only first '=' splits key from value)
|
|||
data := []byte("TargetAddress=10.0.0.1:3260,1\x00Key=a=b=c\x00") |
|||
p, err := ParseParams(data) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
v, ok := p.Get("Key") |
|||
if !ok || v != "a=b=c" { |
|||
t.Fatalf("got %q, %v", v, ok) |
|||
} |
|||
} |
|||
|
|||
func testParseBinaryDataValue(t *testing.T) { |
|||
// Values can contain arbitrary bytes except null
|
|||
data := []byte("Key=\x01\x02\x03\x00") |
|||
p, err := ParseParams(data) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
v, ok := p.Get("Key") |
|||
if !ok || v != "\x01\x02\x03" { |
|||
t.Fatalf("got %q, %v", v, ok) |
|||
} |
|||
} |
|||
@ -0,0 +1,447 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
) |
|||
|
|||
// BHS (Basic Header Segment) is 48 bytes, big-endian.
|
|||
// RFC 7143, Section 12.1
|
|||
const BHSLength = 48 |
|||
|
|||
// Opcode constants — initiator opcodes (RFC 7143, Section 12.1.1)
|
|||
const ( |
|||
OpNOPOut uint8 = 0x00 |
|||
OpSCSICmd uint8 = 0x01 |
|||
OpSCSITaskMgmt uint8 = 0x02 |
|||
OpLoginReq uint8 = 0x03 |
|||
OpTextReq uint8 = 0x04 |
|||
OpSCSIDataOut uint8 = 0x05 |
|||
OpLogoutReq uint8 = 0x06 |
|||
OpSNACKReq uint8 = 0x0c |
|||
) |
|||
|
|||
// Target opcodes (RFC 7143, Section 12.1.1)
|
|||
const ( |
|||
OpNOPIn uint8 = 0x20 |
|||
OpSCSIResp uint8 = 0x21 |
|||
OpSCSITaskResp uint8 = 0x22 |
|||
OpLoginResp uint8 = 0x23 |
|||
OpTextResp uint8 = 0x24 |
|||
OpSCSIDataIn uint8 = 0x25 |
|||
OpLogoutResp uint8 = 0x26 |
|||
OpR2T uint8 = 0x31 |
|||
OpAsyncMsg uint8 = 0x32 |
|||
OpReject uint8 = 0x3f |
|||
) |
|||
|
|||
// BHS flag masks
|
|||
const ( |
|||
opcMask = 0x3f // lower 6 bits of byte 0
|
|||
FlagI = 0x40 // Immediate delivery bit (byte 0)
|
|||
FlagF = 0x80 // Final bit (byte 1)
|
|||
FlagR = 0x40 // Read bit (byte 1, SCSI command)
|
|||
FlagW = 0x20 // Write bit (byte 1, SCSI command)
|
|||
FlagC = 0x40 // Continue bit (byte 1, login)
|
|||
FlagT = 0x80 // Transit bit (byte 1, login)
|
|||
FlagS = 0x01 // Status bit (byte 1, Data-In)
|
|||
FlagU = 0x02 // Underflow bit (byte 1, SCSI Response)
|
|||
FlagO = 0x04 // Overflow bit (byte 1, SCSI Response)
|
|||
FlagBiU = 0x08 // Bidi underflow (byte 1, SCSI Response)
|
|||
FlagBiO = 0x10 // Bidi overflow (byte 1, SCSI Response)
|
|||
FlagA = 0x40 // Acknowledge bit (byte 1, Data-Out)
|
|||
) |
|||
|
|||
// Login stage constants (CSG/NSG, 2-bit fields in byte 1 of login PDU)
|
|||
const ( |
|||
StageSecurityNeg uint8 = 0 // Security Negotiation
|
|||
StageLoginOp uint8 = 1 // Login Operational Negotiation
|
|||
StageFullFeature uint8 = 3 // Full Feature Phase
|
|||
) |
|||
|
|||
// SCSI status codes
|
|||
const ( |
|||
SCSIStatusGood uint8 = 0x00 |
|||
SCSIStatusCheckCond uint8 = 0x02 |
|||
SCSIStatusBusy uint8 = 0x08 |
|||
SCSIStatusResvConflict uint8 = 0x18 |
|||
) |
|||
|
|||
// iSCSI response codes
|
|||
const ( |
|||
ISCSIRespCompleted uint8 = 0x00 |
|||
) |
|||
|
|||
// MaxDataSegmentLength limits the maximum data segment we'll accept.
|
|||
// The iSCSI data segment length is a 3-byte field (max 16MB-1).
|
|||
// We cap at 8MB which is well above typical MaxRecvDataSegmentLength values.
|
|||
const MaxDataSegmentLength = 8 * 1024 * 1024 // 8 MB
|
|||
|
|||
var ( |
|||
ErrPDUTruncated = errors.New("iscsi: PDU truncated") |
|||
ErrPDUTooLarge = errors.New("iscsi: data segment exceeds maximum length") |
|||
ErrInvalidAHSLength = errors.New("iscsi: invalid AHS length (not multiple of 4)") |
|||
ErrUnknownOpcode = errors.New("iscsi: unknown opcode") |
|||
) |
|||
|
|||
// PDU represents a full iSCSI Protocol Data Unit: the fixed 48-byte
// Basic Header Segment plus optional AHS and data segment (RFC 7143).
// The BHS field accessors below interpret raw header bytes; which
// accessor is meaningful depends on the PDU's opcode and direction.
type PDU struct {
	BHS [BHSLength]byte
	AHS []byte // Additional Header Segment (multiple of 4 bytes)
	DataSegment []byte // Data segment (padded to 4-byte boundary on wire)
}
|||
|
|||
// --- BHS field accessors ---
|
|||
|
|||
// Opcode returns the opcode (lower 6 bits of byte 0).
|
|||
func (p *PDU) Opcode() uint8 { return p.BHS[0] & opcMask } |
|||
|
|||
// SetOpcode sets the opcode (lower 6 bits of byte 0).
|
|||
func (p *PDU) SetOpcode(op uint8) { |
|||
p.BHS[0] = (p.BHS[0] & ^uint8(opcMask)) | (op & opcMask) |
|||
} |
|||
|
|||
// Immediate returns true if the immediate delivery bit is set.
|
|||
func (p *PDU) Immediate() bool { return p.BHS[0]&FlagI != 0 } |
|||
|
|||
// SetImmediate sets the immediate delivery bit.
|
|||
func (p *PDU) SetImmediate(v bool) { |
|||
if v { |
|||
p.BHS[0] |= FlagI |
|||
} else { |
|||
p.BHS[0] &^= FlagI |
|||
} |
|||
} |
|||
|
|||
// OpSpecific1 returns byte 1 (opcode-specific flags).
|
|||
func (p *PDU) OpSpecific1() uint8 { return p.BHS[1] } |
|||
|
|||
// SetOpSpecific1 sets byte 1.
|
|||
func (p *PDU) SetOpSpecific1(v uint8) { p.BHS[1] = v } |
|||
|
|||
// TotalAHSLength returns the total AHS length in 4-byte words (byte 4).
|
|||
func (p *PDU) TotalAHSLength() uint8 { return p.BHS[4] } |
|||
|
|||
// DataSegmentLength returns the 3-byte data segment length (bytes 5-7).
|
|||
func (p *PDU) DataSegmentLength() uint32 { |
|||
return uint32(p.BHS[5])<<16 | uint32(p.BHS[6])<<8 | uint32(p.BHS[7]) |
|||
} |
|||
|
|||
// SetDataSegmentLength sets the 3-byte data segment length field.
|
|||
func (p *PDU) SetDataSegmentLength(n uint32) { |
|||
p.BHS[5] = byte(n >> 16) |
|||
p.BHS[6] = byte(n >> 8) |
|||
p.BHS[7] = byte(n) |
|||
} |
|||
|
|||
// LUN returns the 8-byte LUN field (bytes 8-15).
|
|||
func (p *PDU) LUN() uint64 { return binary.BigEndian.Uint64(p.BHS[8:16]) } |
|||
|
|||
// SetLUN sets the LUN field.
|
|||
func (p *PDU) SetLUN(lun uint64) { binary.BigEndian.PutUint64(p.BHS[8:16], lun) } |
|||
|
|||
// InitiatorTaskTag returns bytes 16-19.
|
|||
func (p *PDU) InitiatorTaskTag() uint32 { return binary.BigEndian.Uint32(p.BHS[16:20]) } |
|||
|
|||
// SetInitiatorTaskTag sets bytes 16-19.
|
|||
func (p *PDU) SetInitiatorTaskTag(tag uint32) { binary.BigEndian.PutUint32(p.BHS[16:20], tag) } |
|||
|
|||
// Field32 reads a generic 4-byte field at the given BHS offset.
|
|||
func (p *PDU) Field32(offset int) uint32 { return binary.BigEndian.Uint32(p.BHS[offset : offset+4]) } |
|||
|
|||
// SetField32 writes a generic 4-byte field at the given BHS offset.
|
|||
func (p *PDU) SetField32(offset int, v uint32) { |
|||
binary.BigEndian.PutUint32(p.BHS[offset:offset+4], v) |
|||
} |
|||
|
|||
// --- Common BHS offsets for named fields ---
|
|||
|
|||
// TSIH returns the Target Session Identifying Handle (bytes 14-15, login PDU).
|
|||
func (p *PDU) TSIH() uint16 { return binary.BigEndian.Uint16(p.BHS[14:16]) } |
|||
|
|||
// SetTSIH sets the TSIH field.
|
|||
func (p *PDU) SetTSIH(v uint16) { binary.BigEndian.PutUint16(p.BHS[14:16], v) } |
|||
|
|||
// CmdSN returns bytes 24-27.
// Initiator→target PDUs carry the command sequence number here; the same
// bytes hold StatSN on target→initiator PDUs (see StatSN below).
func (p *PDU) CmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetCmdSN sets bytes 24-27.
func (p *PDU) SetCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpStatSN returns bytes 28-31.
// Initiator→target direction; the same bytes hold ExpCmdSN on responses.
func (p *PDU) ExpStatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpStatSN sets bytes 28-31.
func (p *PDU) SetExpStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// StatSN returns bytes 24-27 (target response PDUs).
// Aliases the same BHS bytes as CmdSN; which name applies depends on the
// PDU direction.
func (p *PDU) StatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetStatSN sets bytes 24-27.
func (p *PDU) SetStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpCmdSN returns bytes 28-31 (target response PDUs). Aliases ExpStatSN.
func (p *PDU) ExpCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpCmdSN sets bytes 28-31.
func (p *PDU) SetExpCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// MaxCmdSN returns bytes 32-35 (target response PDUs).
// Note these bytes overlap the CDB field of a SCSI Command PDU; the
// accessors are direction-specific, not simultaneously valid.
func (p *PDU) MaxCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[32:36]) }

// SetMaxCmdSN sets bytes 32-35.
func (p *PDU) SetMaxCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[32:36], v) }
|||
|
|||
// DataSN returns bytes 36-39 (Data-In / Data-Out PDUs).
|
|||
func (p *PDU) DataSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) } |
|||
|
|||
// SetDataSN sets bytes 36-39.
|
|||
func (p *PDU) SetDataSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) } |
|||
|
|||
// BufferOffset returns bytes 40-43 (Data-In / Data-Out PDUs).
|
|||
func (p *PDU) BufferOffset() uint32 { return binary.BigEndian.Uint32(p.BHS[40:44]) } |
|||
|
|||
// SetBufferOffset sets bytes 40-43.
|
|||
func (p *PDU) SetBufferOffset(v uint32) { binary.BigEndian.PutUint32(p.BHS[40:44], v) } |
|||
|
|||
// R2TSN returns bytes 36-39 (R2T PDUs).
|
|||
func (p *PDU) R2TSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) } |
|||
|
|||
// SetR2TSN sets bytes 36-39.
|
|||
func (p *PDU) SetR2TSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) } |
|||
|
|||
// DesiredDataLength returns bytes 44-47 (R2T PDUs).
|
|||
func (p *PDU) DesiredDataLength() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) } |
|||
|
|||
// SetDesiredDataLength sets bytes 44-47.
|
|||
func (p *PDU) SetDesiredDataLength(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) } |
|||
|
|||
// ExpectedDataTransferLength returns bytes 20-23 (SCSI Command PDUs).
|
|||
func (p *PDU) ExpectedDataTransferLength() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } |
|||
|
|||
// SetExpectedDataTransferLength sets bytes 20-23.
|
|||
func (p *PDU) SetExpectedDataTransferLength(v uint32) { |
|||
binary.BigEndian.PutUint32(p.BHS[20:24], v) |
|||
} |
|||
|
|||
// TargetTransferTag returns bytes 20-23 (target PDUs: R2T, Data-In, etc.).
|
|||
func (p *PDU) TargetTransferTag() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } |
|||
|
|||
// SetTargetTransferTag sets bytes 20-23.
|
|||
func (p *PDU) SetTargetTransferTag(v uint32) { binary.BigEndian.PutUint32(p.BHS[20:24], v) } |
|||
|
|||
// ISID returns the 6-byte Initiator Session ID (bytes 8-13, login PDU).
|
|||
func (p *PDU) ISID() [6]byte { |
|||
var id [6]byte |
|||
copy(id[:], p.BHS[8:14]) |
|||
return id |
|||
} |
|||
|
|||
// SetISID sets the 6-byte ISID field.
|
|||
func (p *PDU) SetISID(id [6]byte) { copy(p.BHS[8:14], id[:]) } |
|||
|
|||
// CDB returns the 16-byte SCSI CDB from the BHS (bytes 32-47, SCSI Command PDU).
|
|||
func (p *PDU) CDB() [16]byte { |
|||
var cdb [16]byte |
|||
copy(cdb[:], p.BHS[32:48]) |
|||
return cdb |
|||
} |
|||
|
|||
// SetCDB sets the 16-byte SCSI CDB in the BHS.
|
|||
func (p *PDU) SetCDB(cdb [16]byte) { copy(p.BHS[32:48], cdb[:]) } |
|||
|
|||
// ResidualCount returns bytes 44-47 (SCSI Response PDUs).
|
|||
func (p *PDU) ResidualCount() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) } |
|||
|
|||
// SetResidualCount sets bytes 44-47.
|
|||
func (p *PDU) SetResidualCount(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) } |
|||
|
|||
// --- Wire I/O ---
|
|||
|
|||
// pad4 rounds n up to the next multiple of 4; iSCSI AHS and data
// segments are 4-byte aligned on the wire.
func pad4(n uint32) uint32 {
	if rem := n % 4; rem != 0 {
		return n + 4 - rem
	}
	return n
}
|||
|
|||
// ReadPDU reads a complete PDU from r.
|
|||
func ReadPDU(r io.Reader) (*PDU, error) { |
|||
p := &PDU{} |
|||
|
|||
// Read BHS (48 bytes)
|
|||
if _, err := io.ReadFull(r, p.BHS[:]); err != nil { |
|||
if err == io.ErrUnexpectedEOF { |
|||
return nil, ErrPDUTruncated |
|||
} |
|||
return nil, err |
|||
} |
|||
|
|||
// AHS
|
|||
ahsLen := uint32(p.TotalAHSLength()) * 4 |
|||
if ahsLen > 0 { |
|||
p.AHS = make([]byte, ahsLen) |
|||
if _, err := io.ReadFull(r, p.AHS); err != nil { |
|||
if err == io.ErrUnexpectedEOF { |
|||
return nil, ErrPDUTruncated |
|||
} |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
// Data segment
|
|||
dsLen := p.DataSegmentLength() |
|||
if dsLen > MaxDataSegmentLength { |
|||
return nil, fmt.Errorf("%w: %d bytes", ErrPDUTooLarge, dsLen) |
|||
} |
|||
|
|||
if dsLen > 0 { |
|||
paddedLen := pad4(dsLen) |
|||
buf := make([]byte, paddedLen) |
|||
if _, err := io.ReadFull(r, buf); err != nil { |
|||
if err == io.ErrUnexpectedEOF { |
|||
return nil, ErrPDUTruncated |
|||
} |
|||
return nil, err |
|||
} |
|||
p.DataSegment = buf[:dsLen] // strip padding
|
|||
} |
|||
|
|||
return p, nil |
|||
} |
|||
|
|||
// WritePDU writes a complete PDU to w, with proper padding.
|
|||
func WritePDU(w io.Writer, p *PDU) error { |
|||
// Update header lengths from actual AHS/DataSegment
|
|||
if len(p.AHS) > 0 { |
|||
if len(p.AHS)%4 != 0 { |
|||
return ErrInvalidAHSLength |
|||
} |
|||
p.BHS[4] = uint8(len(p.AHS) / 4) |
|||
} else { |
|||
p.BHS[4] = 0 |
|||
} |
|||
p.SetDataSegmentLength(uint32(len(p.DataSegment))) |
|||
|
|||
// Write BHS
|
|||
if _, err := w.Write(p.BHS[:]); err != nil { |
|||
return err |
|||
} |
|||
|
|||
// Write AHS
|
|||
if len(p.AHS) > 0 { |
|||
if _, err := w.Write(p.AHS); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
// Write data segment with padding
|
|||
if len(p.DataSegment) > 0 { |
|||
if _, err := w.Write(p.DataSegment); err != nil { |
|||
return err |
|||
} |
|||
padLen := pad4(uint32(len(p.DataSegment))) - uint32(len(p.DataSegment)) |
|||
if padLen > 0 { |
|||
var pad [3]byte |
|||
if _, err := w.Write(pad[:padLen]); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// OpcodeName returns a human-readable name for the given opcode.
|
|||
func OpcodeName(op uint8) string { |
|||
switch op { |
|||
case OpNOPOut: |
|||
return "NOP-Out" |
|||
case OpSCSICmd: |
|||
return "SCSI-Command" |
|||
case OpSCSITaskMgmt: |
|||
return "SCSI-Task-Mgmt" |
|||
case OpLoginReq: |
|||
return "Login-Request" |
|||
case OpTextReq: |
|||
return "Text-Request" |
|||
case OpSCSIDataOut: |
|||
return "SCSI-Data-Out" |
|||
case OpLogoutReq: |
|||
return "Logout-Request" |
|||
case OpSNACKReq: |
|||
return "SNACK-Request" |
|||
case OpNOPIn: |
|||
return "NOP-In" |
|||
case OpSCSIResp: |
|||
return "SCSI-Response" |
|||
case OpSCSITaskResp: |
|||
return "SCSI-Task-Mgmt-Response" |
|||
case OpLoginResp: |
|||
return "Login-Response" |
|||
case OpTextResp: |
|||
return "Text-Response" |
|||
case OpSCSIDataIn: |
|||
return "SCSI-Data-In" |
|||
case OpLogoutResp: |
|||
return "Logout-Response" |
|||
case OpR2T: |
|||
return "R2T" |
|||
case OpAsyncMsg: |
|||
return "Async-Message" |
|||
case OpReject: |
|||
return "Reject" |
|||
default: |
|||
return fmt.Sprintf("Unknown(0x%02x)", op) |
|||
} |
|||
} |
|||
|
|||
// Login PDU helpers
|
|||
|
|||
// LoginCSG returns the Current Stage from byte 1 of a login PDU (bits 3-2).
|
|||
func (p *PDU) LoginCSG() uint8 { return (p.BHS[1] >> 2) & 0x03 } |
|||
|
|||
// LoginNSG returns the Next Stage from byte 1 of a login PDU (bits 1-0).
|
|||
func (p *PDU) LoginNSG() uint8 { return p.BHS[1] & 0x03 } |
|||
|
|||
// SetLoginStages sets the CSG and NSG fields in byte 1, preserving T and C flags.
|
|||
func (p *PDU) SetLoginStages(csg, nsg uint8) { |
|||
p.BHS[1] = (p.BHS[1] & 0xF0) | ((csg & 0x03) << 2) | (nsg & 0x03) |
|||
} |
|||
|
|||
// LoginTransit returns true if the Transit bit is set (byte 1, bit 7).
|
|||
func (p *PDU) LoginTransit() bool { return p.BHS[1]&FlagT != 0 } |
|||
|
|||
// SetLoginTransit sets or clears the Transit bit.
|
|||
func (p *PDU) SetLoginTransit(v bool) { |
|||
if v { |
|||
p.BHS[1] |= FlagT |
|||
} else { |
|||
p.BHS[1] &^= FlagT |
|||
} |
|||
} |
|||
|
|||
// LoginContinue returns true if the Continue bit is set.
|
|||
func (p *PDU) LoginContinue() bool { return p.BHS[1]&FlagC != 0 } |
|||
|
|||
// LoginStatusClass returns the status class (byte 36) of a login response.
|
|||
func (p *PDU) LoginStatusClass() uint8 { return p.BHS[36] } |
|||
|
|||
// LoginStatusDetail returns the status detail (byte 37) of a login response.
|
|||
func (p *PDU) LoginStatusDetail() uint8 { return p.BHS[37] } |
|||
|
|||
// SetLoginStatus sets the status class and detail in a login response.
|
|||
func (p *PDU) SetLoginStatus(class, detail uint8) { |
|||
p.BHS[36] = class |
|||
p.BHS[37] = detail |
|||
} |
|||
|
|||
// SCSIResponse returns the iSCSI response byte (byte 2) of a SCSI Response PDU.
|
|||
func (p *PDU) SCSIResponse() uint8 { return p.BHS[2] } |
|||
|
|||
// SetSCSIResponse sets byte 2.
|
|||
func (p *PDU) SetSCSIResponse(v uint8) { p.BHS[2] = v } |
|||
|
|||
// SCSIStatus returns the SCSI status byte (byte 3) of a SCSI Response PDU.
|
|||
func (p *PDU) SCSIStatus() uint8 { return p.BHS[3] } |
|||
|
|||
// SetSCSIStatus sets byte 3.
|
|||
func (p *PDU) SetSCSIStatus(v uint8) { p.BHS[3] = v } |
|||
@ -0,0 +1,559 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"io" |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
func TestPDU(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{"roundtrip_bhs_only", testRoundtripBHSOnly}, |
|||
{"roundtrip_with_data", testRoundtripWithData}, |
|||
{"roundtrip_with_ahs_and_data", testRoundtripWithAHSAndData}, |
|||
{"data_padding", testDataPadding}, |
|||
{"opcode_accessors", testOpcodeAccessors}, |
|||
{"immediate_flag", testImmediateFlag}, |
|||
{"field32_accessors", testField32Accessors}, |
|||
{"lun_accessor", testLUNAccessor}, |
|||
{"isid_accessor", testISIDAccessor}, |
|||
{"cdb_accessor", testCDBAccessor}, |
|||
{"login_stage_accessors", testLoginStageAccessors}, |
|||
{"login_transit_continue", testLoginTransitContinue}, |
|||
{"login_status", testLoginStatus}, |
|||
{"scsi_response_status", testSCSIResponseStatus}, |
|||
{"data_segment_length_3byte", testDataSegmentLength3Byte}, |
|||
{"tsih_accessor", testTSIHAccessor}, |
|||
{"cmdsn_expstatsn", testCmdSNExpStatSN}, |
|||
{"statsn_expcmdsn_maxcmdsn", testStatSNExpCmdSNMaxCmdSN}, |
|||
{"datasn_bufferoffset", testDataSNBufferOffset}, |
|||
{"r2t_fields", testR2TFields}, |
|||
{"residual_count", testResidualCount}, |
|||
{"expected_data_transfer_length", testExpectedDataTransferLength}, |
|||
{"target_transfer_tag", testTargetTransferTag}, |
|||
{"opcode_name", testOpcodeName}, |
|||
{"read_truncated_bhs", testReadTruncatedBHS}, |
|||
{"read_truncated_data", testReadTruncatedData}, |
|||
{"read_truncated_ahs", testReadTruncatedAHS}, |
|||
{"read_oversized_data", testReadOversizedData}, |
|||
{"read_eof", testReadEOF}, |
|||
{"write_invalid_ahs_length", testWriteInvalidAHSLength}, |
|||
{"roundtrip_all_opcodes", testRoundtripAllOpcodes}, |
|||
{"data_segment_exact_4byte_boundary", testDataSegmentExact4ByteBoundary}, |
|||
{"zero_length_data_segment", testZeroLengthDataSegment}, |
|||
{"max_3byte_data_length", testMax3ByteDataLength}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testRoundtripBHSOnly(t *testing.T) { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpSCSICmd) |
|||
p.SetInitiatorTaskTag(0xDEADBEEF) |
|||
p.SetCmdSN(42) |
|||
|
|||
var buf bytes.Buffer |
|||
if err := WritePDU(&buf, p); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if buf.Len() != BHSLength { |
|||
t.Fatalf("expected %d bytes, got %d", BHSLength, buf.Len()) |
|||
} |
|||
|
|||
p2, err := ReadPDU(&buf) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if p2.Opcode() != OpSCSICmd { |
|||
t.Fatalf("opcode mismatch: got 0x%02x", p2.Opcode()) |
|||
} |
|||
if p2.InitiatorTaskTag() != 0xDEADBEEF { |
|||
t.Fatalf("ITT mismatch: got 0x%x", p2.InitiatorTaskTag()) |
|||
} |
|||
if p2.CmdSN() != 42 { |
|||
t.Fatalf("CmdSN mismatch: got %d", p2.CmdSN()) |
|||
} |
|||
} |
|||
|
|||
func testRoundtripWithData(t *testing.T) { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpSCSIDataIn) |
|||
p.DataSegment = []byte("hello, iSCSI world!") |
|||
|
|||
var buf bytes.Buffer |
|||
if err := WritePDU(&buf, p); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
// BHS(48) + data(19) + padding(1) = 68
|
|||
expectedLen := 48 + pad4(uint32(len(p.DataSegment))) |
|||
if uint32(buf.Len()) != expectedLen { |
|||
t.Fatalf("wire length: expected %d, got %d", expectedLen, buf.Len()) |
|||
} |
|||
|
|||
p2, err := ReadPDU(&buf) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if !bytes.Equal(p2.DataSegment, []byte("hello, iSCSI world!")) { |
|||
t.Fatalf("data mismatch: %q", p2.DataSegment) |
|||
} |
|||
} |
|||
|
|||
func testRoundtripWithAHSAndData(t *testing.T) { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpSCSICmd) |
|||
p.AHS = make([]byte, 8) // 2 words
|
|||
p.AHS[0] = 0xAA |
|||
p.AHS[7] = 0xBB |
|||
p.DataSegment = []byte{1, 2, 3, 4, 5} |
|||
|
|||
var buf bytes.Buffer |
|||
if err := WritePDU(&buf, p); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
p2, err := ReadPDU(&buf) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if len(p2.AHS) != 8 { |
|||
t.Fatalf("AHS length: expected 8, got %d", len(p2.AHS)) |
|||
} |
|||
if p2.AHS[0] != 0xAA || p2.AHS[7] != 0xBB { |
|||
t.Fatal("AHS content mismatch") |
|||
} |
|||
if !bytes.Equal(p2.DataSegment, []byte{1, 2, 3, 4, 5}) { |
|||
t.Fatal("data segment mismatch") |
|||
} |
|||
} |
|||
|
|||
func testDataPadding(t *testing.T) { |
|||
for _, dataLen := range []int{1, 2, 3, 4, 5, 7, 8, 100, 1023} { |
|||
p := &PDU{} |
|||
p.SetOpcode(OpSCSIDataIn) |
|||
p.DataSegment = bytes.Repeat([]byte{0xFF}, dataLen) |
|||
|
|||
var buf bytes.Buffer |
|||
if err := WritePDU(&buf, p); err != nil { |
|||
t.Fatalf("dataLen=%d: write: %v", dataLen, err) |
|||
} |
|||
|
|||
expectedWire := BHSLength + int(pad4(uint32(dataLen))) |
|||
if buf.Len() != expectedWire { |
|||
t.Fatalf("dataLen=%d: wire=%d, expected=%d", dataLen, buf.Len(), expectedWire) |
|||
} |
|||
|
|||
p2, err := ReadPDU(&buf) |
|||
if err != nil { |
|||
t.Fatalf("dataLen=%d: read: %v", dataLen, err) |
|||
} |
|||
if len(p2.DataSegment) != dataLen { |
|||
t.Fatalf("dataLen=%d: got %d", dataLen, len(p2.DataSegment)) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testOpcodeAccessors(t *testing.T) { |
|||
p := &PDU{} |
|||
// Set immediate first, then opcode — verify no interference
|
|||
p.SetImmediate(true) |
|||
p.SetOpcode(OpLoginReq) |
|||
if p.Opcode() != OpLoginReq { |
|||
t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpLoginReq) |
|||
} |
|||
if !p.Immediate() { |
|||
t.Fatal("immediate flag lost") |
|||
} |
|||
|
|||
// Change opcode, verify immediate preserved
|
|||
p.SetOpcode(OpTextReq) |
|||
if p.Opcode() != OpTextReq { |
|||
t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpTextReq) |
|||
} |
|||
if !p.Immediate() { |
|||
t.Fatal("immediate flag lost after opcode change") |
|||
} |
|||
} |
|||
|
|||
func testImmediateFlag(t *testing.T) { |
|||
p := &PDU{} |
|||
if p.Immediate() { |
|||
t.Fatal("should start false") |
|||
} |
|||
p.SetImmediate(true) |
|||
if !p.Immediate() { |
|||
t.Fatal("should be true") |
|||
} |
|||
p.SetImmediate(false) |
|||
if p.Immediate() { |
|||
t.Fatal("should be false again") |
|||
} |
|||
} |
|||
|
|||
func testField32Accessors(t *testing.T) { |
|||
p := &PDU{} |
|||
p.SetField32(20, 0x12345678) |
|||
if got := p.Field32(20); got != 0x12345678 { |
|||
t.Fatalf("got 0x%08x", got) |
|||
} |
|||
} |
|||
|
|||
func testLUNAccessor(t *testing.T) { |
|||
p := &PDU{} |
|||
p.SetLUN(0x0001000000000000) // LUN 1 in SAM encoding
|
|||
if p.LUN() != 0x0001000000000000 { |
|||
t.Fatalf("LUN mismatch: 0x%016x", p.LUN()) |
|||
} |
|||
} |
|||
|
|||
func testISIDAccessor(t *testing.T) { |
|||
p := &PDU{} |
|||
isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} |
|||
p.SetISID(isid) |
|||
got := p.ISID() |
|||
if got != isid { |
|||
t.Fatalf("ISID mismatch: %v", got) |
|||
} |
|||
} |
|||
|
|||
// testCDBAccessor verifies SetCDB/CDB round-trips a full 16-byte CDB.
func testCDBAccessor(t *testing.T) {
	p := &PDU{}
	var cdb [16]byte
	cdb[0] = 0x28 // READ_10
	cdb[1] = 0x00
	binary.BigEndian.PutUint32(cdb[2:6], 100) // LBA
	binary.BigEndian.PutUint16(cdb[7:9], 8)   // transfer length
	p.SetCDB(cdb)
	got := p.CDB()
	if got != cdb {
		t.Fatal("CDB mismatch")
	}
}

// testLoginStageAccessors verifies CSG/NSG round-trip through SetLoginStages,
// including overwriting previously set stages.
func testLoginStageAccessors(t *testing.T) {
	p := &PDU{}
	p.SetLoginStages(StageSecurityNeg, StageLoginOp)
	if p.LoginCSG() != StageSecurityNeg {
		t.Fatalf("CSG: got %d", p.LoginCSG())
	}
	if p.LoginNSG() != StageLoginOp {
		t.Fatalf("NSG: got %d", p.LoginNSG())
	}

	// Change to LoginOp -> FullFeature
	p.SetLoginStages(StageLoginOp, StageFullFeature)
	if p.LoginCSG() != StageLoginOp {
		t.Fatalf("CSG: got %d", p.LoginCSG())
	}
	if p.LoginNSG() != StageFullFeature {
		t.Fatalf("NSG: got %d", p.LoginNSG())
	}
}

// testLoginTransitContinue verifies the Transit/Continue flag accessors start
// false and that SetLoginStages does not clobber the Transit bit.
func testLoginTransitContinue(t *testing.T) {
	p := &PDU{}
	if p.LoginTransit() {
		t.Fatal("Transit should start false")
	}
	if p.LoginContinue() {
		t.Fatal("Continue should start false")
	}

	p.SetLoginTransit(true)
	p.SetLoginStages(StageLoginOp, StageFullFeature)
	if !p.LoginTransit() {
		t.Fatal("Transit should be true")
	}
	// Verify stages preserved
	if p.LoginCSG() != StageLoginOp {
		t.Fatal("CSG lost after setting transit")
	}
}

// testLoginStatus verifies login status class/detail round-trip.
func testLoginStatus(t *testing.T) {
	p := &PDU{}
	p.SetLoginStatus(0x02, 0x01) // Initiator error, authentication failure
	if p.LoginStatusClass() != 0x02 {
		t.Fatalf("class: got %d", p.LoginStatusClass())
	}
	if p.LoginStatusDetail() != 0x01 {
		t.Fatalf("detail: got %d", p.LoginStatusDetail())
	}
}

// testSCSIResponseStatus verifies the SCSI response and status byte accessors.
func testSCSIResponseStatus(t *testing.T) {
	p := &PDU{}
	p.SetSCSIResponse(ISCSIRespCompleted)
	p.SetSCSIStatus(SCSIStatusGood)
	if p.SCSIResponse() != ISCSIRespCompleted {
		t.Fatal("response mismatch")
	}
	if p.SCSIStatus() != SCSIStatusGood {
		t.Fatal("status mismatch")
	}
}

// testDataSegmentLength3Byte verifies the 3-byte DataSegmentLength field over
// its full range, including sizes that need each of the three bytes.
func testDataSegmentLength3Byte(t *testing.T) {
	p := &PDU{}
	// Test various sizes including >64KB (needs all 3 bytes)
	for _, size := range []uint32{0, 1, 255, 256, 65535, 65536, 1<<24 - 1} {
		p.SetDataSegmentLength(size)
		if got := p.DataSegmentLength(); got != size {
			t.Fatalf("size %d: got %d", size, got)
		}
	}
}

// testTSIHAccessor verifies the 16-bit TSIH round-trips.
func testTSIHAccessor(t *testing.T) {
	p := &PDU{}
	p.SetTSIH(0x1234)
	if p.TSIH() != 0x1234 {
		t.Fatalf("TSIH: got 0x%04x", p.TSIH())
	}
}

// testCmdSNExpStatSN verifies the initiator-side sequence-number accessors.
func testCmdSNExpStatSN(t *testing.T) {
	p := &PDU{}
	p.SetCmdSN(100)
	p.SetExpStatSN(200)
	if p.CmdSN() != 100 {
		t.Fatal("CmdSN mismatch")
	}
	if p.ExpStatSN() != 200 {
		t.Fatal("ExpStatSN mismatch")
	}
}

// testStatSNExpCmdSNMaxCmdSN verifies the target-side sequence-number accessors.
func testStatSNExpCmdSNMaxCmdSN(t *testing.T) {
	p := &PDU{}
	p.SetStatSN(10)
	p.SetExpCmdSN(20)
	p.SetMaxCmdSN(30)
	if p.StatSN() != 10 {
		t.Fatal("StatSN mismatch")
	}
	if p.ExpCmdSN() != 20 {
		t.Fatal("ExpCmdSN mismatch")
	}
	if p.MaxCmdSN() != 30 {
		t.Fatal("MaxCmdSN mismatch")
	}
}

// testDataSNBufferOffset verifies the Data-In/Out DataSN and BufferOffset fields.
func testDataSNBufferOffset(t *testing.T) {
	p := &PDU{}
	p.SetDataSN(5)
	p.SetBufferOffset(8192)
	if p.DataSN() != 5 {
		t.Fatal("DataSN mismatch")
	}
	if p.BufferOffset() != 8192 {
		t.Fatal("BufferOffset mismatch")
	}
}

// testR2TFields verifies the R2T sequence number and desired data length fields.
func testR2TFields(t *testing.T) {
	p := &PDU{}
	p.SetR2TSN(3)
	p.SetDesiredDataLength(65536)
	if p.R2TSN() != 3 {
		t.Fatal("R2TSN mismatch")
	}
	if p.DesiredDataLength() != 65536 {
		t.Fatal("DesiredDataLength mismatch")
	}
}

// testResidualCount verifies the residual count field round-trips.
func testResidualCount(t *testing.T) {
	p := &PDU{}
	p.SetResidualCount(512)
	if p.ResidualCount() != 512 {
		t.Fatal("ResidualCount mismatch")
	}
}

// testExpectedDataTransferLength verifies the EDTL field round-trips.
func testExpectedDataTransferLength(t *testing.T) {
	p := &PDU{}
	p.SetExpectedDataTransferLength(1048576)
	if p.ExpectedDataTransferLength() != 1048576 {
		t.Fatal("mismatch")
	}
}

// testTargetTransferTag verifies TTT round-trips, including the reserved
// all-ones value 0xFFFFFFFF.
func testTargetTransferTag(t *testing.T) {
	p := &PDU{}
	p.SetTargetTransferTag(0xFFFFFFFF)
	if p.TargetTransferTag() != 0xFFFFFFFF {
		t.Fatal("mismatch")
	}
}
|||
|
|||
// testOpcodeName spot-checks known opcode names and the Unknown fallback.
func testOpcodeName(t *testing.T) {
	if OpcodeName(OpSCSICmd) != "SCSI-Command" {
		t.Fatal("wrong name for SCSI-Command")
	}
	if OpcodeName(OpLoginReq) != "Login-Request" {
		t.Fatal("wrong name for Login-Request")
	}
	if !strings.HasPrefix(OpcodeName(0xFF), "Unknown") {
		t.Fatal("unknown opcode should have Unknown prefix")
	}
}

// testReadTruncatedBHS verifies ReadPDU rejects input shorter than a BHS.
func testReadTruncatedBHS(t *testing.T) {
	// Only 20 bytes — not enough for BHS
	buf := bytes.NewReader(make([]byte, 20))
	_, err := ReadPDU(buf)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testReadTruncatedData verifies ReadPDU rejects a PDU whose declared data
// segment length exceeds the bytes actually available.
func testReadTruncatedData(t *testing.T) {
	// Valid BHS claiming 100 bytes of data, but only 10 bytes follow
	var bhs [BHSLength]byte
	bhs[5] = 0
	bhs[6] = 0
	bhs[7] = 100 // DataSegmentLength = 100
	buf := make([]byte, BHSLength+10)
	copy(buf, bhs[:])

	_, err := ReadPDU(bytes.NewReader(buf))
	if err == nil {
		t.Fatal("expected truncation error")
	}
}

// testReadTruncatedAHS verifies ReadPDU rejects a PDU whose declared AHS
// length exceeds the bytes actually available.
func testReadTruncatedAHS(t *testing.T) {
	// BHS claiming 2 words of AHS but only 4 bytes follow
	var bhs [BHSLength]byte
	bhs[4] = 2 // TotalAHSLength = 2 words = 8 bytes
	buf := make([]byte, BHSLength+4)
	copy(buf, bhs[:])

	_, err := ReadPDU(bytes.NewReader(buf))
	if err == nil {
		t.Fatal("expected truncation error")
	}
}

// testReadOversizedData verifies ReadPDU rejects a data segment length above
// MaxDataSegmentLength with a descriptive error.
func testReadOversizedData(t *testing.T) {
	// BHS claiming MaxDataSegmentLength+1 bytes
	var bhs [BHSLength]byte
	oversized := uint32(MaxDataSegmentLength + 1)
	bhs[5] = byte(oversized >> 16)
	bhs[6] = byte(oversized >> 8)
	bhs[7] = byte(oversized)

	_, err := ReadPDU(bytes.NewReader(bhs[:]))
	if err == nil {
		t.Fatal("expected oversized error")
	}
	if !strings.Contains(err.Error(), "exceeds maximum") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testReadEOF verifies that a clean end-of-stream surfaces as io.EOF
// (unwrapped), which callers use to detect connection close.
func testReadEOF(t *testing.T) {
	_, err := ReadPDU(bytes.NewReader(nil))
	if err != io.EOF {
		t.Fatalf("expected io.EOF, got %v", err)
	}
}

// testWriteInvalidAHSLength verifies WritePDU rejects an AHS that is not a
// multiple of 4 bytes with the sentinel error.
func testWriteInvalidAHSLength(t *testing.T) {
	p := &PDU{}
	p.AHS = make([]byte, 5) // not multiple of 4
	var buf bytes.Buffer
	err := WritePDU(&buf, p)
	if err != ErrInvalidAHSLength {
		t.Fatalf("expected ErrInvalidAHSLength, got %v", err)
	}
}
|||
|
|||
// testRoundtripAllOpcodes verifies every defined opcode survives a
// WritePDU/ReadPDU round-trip.
func testRoundtripAllOpcodes(t *testing.T) {
	opcodes := []uint8{
		OpNOPOut, OpSCSICmd, OpSCSITaskMgmt, OpLoginReq, OpTextReq,
		OpSCSIDataOut, OpLogoutReq, OpSNACKReq,
		OpNOPIn, OpSCSIResp, OpSCSITaskResp, OpLoginResp, OpTextResp,
		OpSCSIDataIn, OpLogoutResp, OpR2T, OpAsyncMsg, OpReject,
	}
	for _, op := range opcodes {
		p := &PDU{}
		p.SetOpcode(op)
		var buf bytes.Buffer
		if err := WritePDU(&buf, p); err != nil {
			t.Fatalf("op=0x%02x: write: %v", op, err)
		}
		p2, err := ReadPDU(&buf)
		if err != nil {
			t.Fatalf("op=0x%02x: read: %v", op, err)
		}
		if p2.Opcode() != op {
			t.Fatalf("op=0x%02x: got 0x%02x", op, p2.Opcode())
		}
	}
}

// testDataSegmentExact4ByteBoundary verifies that 4-byte-aligned data
// segments are written without any padding bytes on the wire.
func testDataSegmentExact4ByteBoundary(t *testing.T) {
	for _, size := range []int{4, 8, 12, 16, 256, 4096} {
		p := &PDU{}
		p.DataSegment = bytes.Repeat([]byte{0xAB}, size)
		var buf bytes.Buffer
		if err := WritePDU(&buf, p); err != nil {
			t.Fatalf("size=%d: %v", size, err)
		}
		// No padding needed
		if buf.Len() != BHSLength+size {
			t.Fatalf("size=%d: wire=%d, expected=%d", size, buf.Len(), BHSLength+size)
		}
		p2, err := ReadPDU(&buf)
		if err != nil {
			t.Fatalf("size=%d: read: %v", size, err)
		}
		if len(p2.DataSegment) != size {
			t.Fatalf("size=%d: got %d", size, len(p2.DataSegment))
		}
	}
}

// testZeroLengthDataSegment verifies a nil data segment writes exactly one
// BHS and reads back as empty.
func testZeroLengthDataSegment(t *testing.T) {
	p := &PDU{}
	p.SetOpcode(OpNOPOut)
	// DataSegment is nil

	var buf bytes.Buffer
	if err := WritePDU(&buf, p); err != nil {
		t.Fatal(err)
	}
	if buf.Len() != BHSLength {
		t.Fatalf("expected %d, got %d", BHSLength, buf.Len())
	}

	p2, err := ReadPDU(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if len(p2.DataSegment) != 0 {
		t.Fatalf("expected empty data segment, got %d bytes", len(p2.DataSegment))
	}
}

// testMax3ByteDataLength verifies the maximum representable 3-byte length
// (2^24-1) round-trips through the accessor.
func testMax3ByteDataLength(t *testing.T) {
	p := &PDU{}
	maxLen := uint32(1<<24 - 1) // 16MB - 1
	p.SetDataSegmentLength(maxLen)
	if got := p.DataSegmentLength(); got != maxLen {
		t.Fatalf("expected %d, got %d", maxLen, got)
	}
}
|||
@ -0,0 +1,463 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
) |
|||
|
|||
// SCSI opcode constants (SPC-5 / SBC-4), listed in ascending opcode order.
const (
	ScsiTestUnitReady  uint8 = 0x00
	ScsiInquiry        uint8 = 0x12
	ScsiModeSense6     uint8 = 0x1a
	ScsiReadCapacity10 uint8 = 0x25
	ScsiRead10         uint8 = 0x28
	ScsiWrite10        uint8 = 0x2a
	ScsiSyncCache10    uint8 = 0x35
	ScsiUnmap          uint8 = 0x42
	ScsiRead16         uint8 = 0x88
	ScsiWrite16        uint8 = 0x8a
	ScsiSyncCache16    uint8 = 0x91
	ScsiReadCapacity16 uint8 = 0x9e // SERVICE ACTION IN (16), SA=0x10
	ScsiReportLuns     uint8 = 0xa0
)
|||
|
|||
// ScsiSAReadCapacity16 is the SERVICE ACTION code for READ CAPACITY (16),
// carried in the low 5 bits of CDB byte 1 under opcode 0x9e.
const ScsiSAReadCapacity16 uint8 = 0x10
|||
|
|||
// SCSI sense keys (SPC-5), reported in fixed-format sense data byte 2.
const (
	SenseNoSense        uint8 = 0x00 // no specific sense information
	SenseNotReady       uint8 = 0x02 // LUN cannot be accessed
	SenseMediumError    uint8 = 0x03 // unrecoverable read/write error
	SenseHardwareError  uint8 = 0x04 // non-recoverable hardware failure
	SenseIllegalRequest uint8 = 0x05 // invalid CDB field or opcode
	SenseAbortedCommand uint8 = 0x0b // command aborted by device server
)
|||
|
|||
// ASC/ASCQ pairs (additional sense code / qualifier) used by this handler.
const (
	ASCInvalidOpcode     uint8 = 0x20 // INVALID COMMAND OPERATION CODE
	ASCQLuk              uint8 = 0x00 // default qualifier (0x00) — NOTE(review): name "Luk" is unclear; presumably "default/none", confirm intent
	ASCInvalidFieldInCDB uint8 = 0x24 // INVALID FIELD IN CDB
	ASCLBAOutOfRange     uint8 = 0x21 // LOGICAL BLOCK ADDRESS OUT OF RANGE
	ASCNotReady          uint8 = 0x04 // LOGICAL UNIT NOT READY
	ASCQNotReady         uint8 = 0x03 // manual intervention required
)
|||
|
|||
// BlockDevice is the interface that the SCSI command handler uses to
// interact with the underlying storage. This maps onto BlockVol.
//
// LBAs are block-granular; byte lengths must be multiples of BlockSize().
type BlockDevice interface {
	// ReadAt reads length bytes starting at block address lba.
	ReadAt(lba uint64, length uint32) ([]byte, error)
	// WriteAt writes data (a whole number of blocks) starting at lba.
	WriteAt(lba uint64, data []byte) error
	// Trim discards length bytes starting at lba (UNMAP backing).
	Trim(lba uint64, length uint32) error
	// SyncCache flushes any buffered writes to stable storage.
	SyncCache() error
	// BlockSize returns the logical block size in bytes.
	BlockSize() uint32
	// VolumeSize returns the total size in bytes.
	VolumeSize() uint64
	// IsHealthy reports whether the device can currently serve I/O.
	IsHealthy() bool
}
|||
|
|||
// SCSIHandler processes SCSI commands from iSCSI PDUs.
type SCSIHandler struct {
	dev      BlockDevice
	vendorID string // T10 vendor id, space-padded to 8 bytes in INQUIRY
	prodID   string // product id, space-padded to 16 bytes in INQUIRY
	serial   string // unit serial number for VPD page 0x80
}
|||
|
|||
// NewSCSIHandler creates a SCSI command handler for the given block device.
|
|||
func NewSCSIHandler(dev BlockDevice) *SCSIHandler { |
|||
return &SCSIHandler{ |
|||
dev: dev, |
|||
vendorID: "SeaweedF", |
|||
prodID: "BlockVol ", |
|||
serial: "SWF00001", |
|||
} |
|||
} |
|||
|
|||
// SCSIResult holds the result of a SCSI command execution.
// Sense fields are meaningful only when Status is CHECK CONDITION.
type SCSIResult struct {
	Status    uint8  // SCSI status byte (GOOD, CHECK CONDITION, ...)
	Data      []byte // Response data (for Data-In)
	SenseKey  uint8  // Sense key (if CHECK_CONDITION)
	SenseASC  uint8  // Additional sense code
	SenseASCQ uint8  // Additional sense code qualifier
}
|||
|
|||
// HandleCommand dispatches a SCSI CDB to the appropriate handler.
|
|||
// dataOut contains any data sent by the initiator (for WRITE commands).
|
|||
func (h *SCSIHandler) HandleCommand(cdb [16]byte, dataOut []byte) SCSIResult { |
|||
opcode := cdb[0] |
|||
|
|||
switch opcode { |
|||
case ScsiTestUnitReady: |
|||
return h.testUnitReady() |
|||
case ScsiInquiry: |
|||
return h.inquiry(cdb) |
|||
case ScsiModeSense6: |
|||
return h.modeSense6(cdb) |
|||
case ScsiReadCapacity10: |
|||
return h.readCapacity10() |
|||
case ScsiReadCapacity16: |
|||
sa := cdb[1] & 0x1f |
|||
if sa == ScsiSAReadCapacity16 { |
|||
return h.readCapacity16(cdb) |
|||
} |
|||
return illegalRequest(ASCInvalidOpcode, ASCQLuk) |
|||
case ScsiReportLuns: |
|||
return h.reportLuns(cdb) |
|||
case ScsiRead10: |
|||
return h.read10(cdb) |
|||
case ScsiRead16: |
|||
return h.read16(cdb) |
|||
case ScsiWrite10: |
|||
return h.write10(cdb, dataOut) |
|||
case ScsiWrite16: |
|||
return h.write16(cdb, dataOut) |
|||
case ScsiSyncCache10: |
|||
return h.syncCache() |
|||
case ScsiSyncCache16: |
|||
return h.syncCache() |
|||
case ScsiUnmap: |
|||
return h.unmap(cdb, dataOut) |
|||
default: |
|||
return illegalRequest(ASCInvalidOpcode, ASCQLuk) |
|||
} |
|||
} |
|||
|
|||
// --- Metadata commands ---
|
|||
|
|||
func (h *SCSIHandler) testUnitReady() SCSIResult { |
|||
if !h.dev.IsHealthy() { |
|||
return SCSIResult{ |
|||
Status: SCSIStatusCheckCond, |
|||
SenseKey: SenseNotReady, |
|||
SenseASC: ASCNotReady, |
|||
SenseASCQ: ASCQNotReady, |
|||
} |
|||
} |
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
// inquiry handles the INQUIRY command. With EVPD set it delegates to the
// VPD page handler; otherwise it returns the 96-byte standard INQUIRY data,
// truncated to the initiator's allocation length.
func (h *SCSIHandler) inquiry(cdb [16]byte) SCSIResult {
	evpd := cdb[1] & 0x01
	pageCode := cdb[2]
	allocLen := binary.BigEndian.Uint16(cdb[3:5])
	if allocLen == 0 {
		// NOTE(review): SPC says allocation length 0 means "return no data";
		// defaulting to 36 is a lenient deviation — confirm it is intentional.
		allocLen = 36
	}

	if evpd != 0 {
		return h.inquiryVPD(pageCode, allocLen)
	}

	// Standard INQUIRY response (SPC-5, Section 6.6.1)
	data := make([]byte, 96)
	data[0] = 0x00 // Peripheral device type: SBC (direct access block device)
	data[1] = 0x00 // RMB=0 (not removable)
	data[2] = 0x06 // SPC-4 version
	data[3] = 0x02 // Response data format = 2 (SPC-2+)
	data[4] = 91   // Additional length (96-5)
	data[5] = 0x00 // SCCS, ACC, TPGS, 3PC
	data[6] = 0x00 // Obsolete, EncServ, VS, MultiP
	data[7] = 0x02 // CmdQue=1 (supports command queuing)

	// Vendor ID (bytes 8-15, 8 chars, space padded)
	copy(data[8:16], padRight(h.vendorID, 8))
	// Product ID (bytes 16-31, 16 chars, space padded)
	copy(data[16:32], padRight(h.prodID, 16))
	// Product revision (bytes 32-35, 4 chars)
	copy(data[32:36], "0001")

	if int(allocLen) < len(data) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
|||
|
|||
// inquiryVPD serves the supported VPD pages: 0x00 (supported page list),
// 0x80 (unit serial number), and 0x83 (device identification with a static
// NAA identifier). Unknown pages get ILLEGAL REQUEST / INVALID FIELD IN CDB.
// All responses are truncated to allocLen.
func (h *SCSIHandler) inquiryVPD(pageCode uint8, allocLen uint16) SCSIResult {
	switch pageCode {
	case 0x00: // Supported VPD pages
		data := []byte{
			0x00, // device type
			0x00, // page code
			0x00, 0x03, // page length
			0x00, // supported pages: 0x00
			0x80, // 0x80 (serial)
			0x83, // 0x83 (device identification)
		}
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	case 0x80: // Unit serial number
		serial := padRight(h.serial, 8)
		data := make([]byte, 4+len(serial))
		data[0] = 0x00 // device type
		data[1] = 0x80 // page code
		binary.BigEndian.PutUint16(data[2:4], uint16(len(serial)))
		copy(data[4:], serial)
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	case 0x83: // Device identification
		// Single designation descriptor: 4-byte header + 8-byte NAA identifier.
		// NOTE(review): the NAA value is a hard-coded placeholder; every LUN
		// reports the same identifier — confirm this is acceptable for MPIO.
		naaID := []byte{
			0x01, // code set: binary
			0x03, // identifier type: NAA
			0x00, // reserved
			0x08, // identifier length
			0x60, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // NAA-6 fake
		}
		data := make([]byte, 4+len(naaID))
		data[0] = 0x00 // device type
		data[1] = 0x83 // page code
		binary.BigEndian.PutUint16(data[2:4], uint16(len(naaID)))
		copy(data[4:], naaID)
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	default:
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
}
|||
|
|||
// readCapacity10 returns the 8-byte READ CAPACITY (10) data: last LBA and
// block size, both big-endian.
func (h *SCSIHandler) readCapacity10() SCSIResult {
	blockSize := h.dev.BlockSize()
	// NOTE(review): if VolumeSize() < BlockSize(), totalBlocks is 0 and
	// totalBlocks-1 below underflows to 0xFFFFFFFF — confirm zero-size
	// volumes cannot reach this path.
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)

	data := make([]byte, 8)
	// If >2TB (blocks > 0xFFFFFFFF), return 0xFFFFFFFF to signal use READ_CAPACITY_16
	if totalBlocks > 0xFFFFFFFF {
		binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF)
	} else {
		binary.BigEndian.PutUint32(data[0:4], uint32(totalBlocks-1)) // last LBA
	}
	binary.BigEndian.PutUint32(data[4:8], blockSize)

	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
|||
|
|||
// readCapacity16 returns the 32-byte READ CAPACITY (16) data: last LBA,
// block size, and the LBPME bit advertising UNMAP support.
func (h *SCSIHandler) readCapacity16(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint32(cdb[10:14])
	if allocLen < 32 {
		// Clamp up so the full structure is always built; truncation below
		// then never applies for sub-32 allocation lengths.
		allocLen = 32
	}

	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)

	data := make([]byte, 32)
	binary.BigEndian.PutUint64(data[0:8], totalBlocks-1) // last LBA
	binary.BigEndian.PutUint32(data[8:12], blockSize)    // block length
	// LBPME (logical block provisioning management enabled) lives in byte 14,
	// bit 7; set to advertise UNMAP support.
	data[14] = 0x80 // LBPME bit

	if allocLen < uint32(len(data)) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
|||
|
|||
// modeSense6 handles MODE SENSE (6) with a minimal 4-byte header only —
// no block descriptors and no mode pages are reported.
func (h *SCSIHandler) modeSense6(cdb [16]byte) SCSIResult {
	// Minimal MODE SENSE(6) response — no mode pages
	allocLen := cdb[4]
	if allocLen == 0 {
		allocLen = 4
	}

	data := make([]byte, 4)
	data[0] = 3    // Mode data length (3 bytes follow)
	data[1] = 0x00 // Medium type: default
	data[2] = 0x00 // Device-specific parameter (no write protect)
	data[3] = 0x00 // Block descriptor length = 0

	if int(allocLen) < len(data) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
|||
|
|||
// reportLuns handles REPORT LUNS, always reporting exactly one LUN (LUN 0).
func (h *SCSIHandler) reportLuns(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint32(cdb[6:10])
	if allocLen < 16 {
		allocLen = 16
	}

	// Report a single LUN (LUN 0)
	data := make([]byte, 16)
	binary.BigEndian.PutUint32(data[0:4], 8) // LUN list length (8 bytes, 1 LUN)
	// data[4:8] reserved
	// data[8:16] = LUN 0 (all zeros)

	if allocLen < uint32(len(data)) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
|||
|
|||
// --- Data commands (Task 2.6) ---
|
|||
|
|||
func (h *SCSIHandler) read10(cdb [16]byte) SCSIResult { |
|||
lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) |
|||
transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) |
|||
return h.doRead(lba, transferLen) |
|||
} |
|||
|
|||
func (h *SCSIHandler) read16(cdb [16]byte) SCSIResult { |
|||
lba := binary.BigEndian.Uint64(cdb[2:10]) |
|||
transferLen := binary.BigEndian.Uint32(cdb[10:14]) |
|||
return h.doRead(lba, transferLen) |
|||
} |
|||
|
|||
func (h *SCSIHandler) write10(cdb [16]byte, dataOut []byte) SCSIResult { |
|||
lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) |
|||
transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) |
|||
return h.doWrite(lba, transferLen, dataOut) |
|||
} |
|||
|
|||
func (h *SCSIHandler) write16(cdb [16]byte, dataOut []byte) SCSIResult { |
|||
lba := binary.BigEndian.Uint64(cdb[2:10]) |
|||
transferLen := binary.BigEndian.Uint32(cdb[10:14]) |
|||
return h.doWrite(lba, transferLen, dataOut) |
|||
} |
|||
|
|||
func (h *SCSIHandler) doRead(lba uint64, transferLen uint32) SCSIResult { |
|||
if transferLen == 0 { |
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
blockSize := h.dev.BlockSize() |
|||
totalBlocks := h.dev.VolumeSize() / uint64(blockSize) |
|||
|
|||
if lba+uint64(transferLen) > totalBlocks { |
|||
return illegalRequest(ASCLBAOutOfRange, ASCQLuk) |
|||
} |
|||
|
|||
byteLen := transferLen * blockSize |
|||
data, err := h.dev.ReadAt(lba, byteLen) |
|||
if err != nil { |
|||
return SCSIResult{ |
|||
Status: SCSIStatusCheckCond, |
|||
SenseKey: SenseMediumError, |
|||
SenseASC: 0x11, // Unrecovered read error
|
|||
SenseASCQ: 0x00, |
|||
} |
|||
} |
|||
|
|||
return SCSIResult{Status: SCSIStatusGood, Data: data} |
|||
} |
|||
|
|||
func (h *SCSIHandler) doWrite(lba uint64, transferLen uint32, dataOut []byte) SCSIResult { |
|||
if transferLen == 0 { |
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
blockSize := h.dev.BlockSize() |
|||
totalBlocks := h.dev.VolumeSize() / uint64(blockSize) |
|||
|
|||
if lba+uint64(transferLen) > totalBlocks { |
|||
return illegalRequest(ASCLBAOutOfRange, ASCQLuk) |
|||
} |
|||
|
|||
expectedBytes := transferLen * blockSize |
|||
if uint32(len(dataOut)) < expectedBytes { |
|||
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
|||
} |
|||
|
|||
if err := h.dev.WriteAt(lba, dataOut[:expectedBytes]); err != nil { |
|||
return SCSIResult{ |
|||
Status: SCSIStatusCheckCond, |
|||
SenseKey: SenseMediumError, |
|||
SenseASC: 0x0C, // Write error
|
|||
SenseASCQ: 0x00, |
|||
} |
|||
} |
|||
|
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
func (h *SCSIHandler) syncCache() SCSIResult { |
|||
if err := h.dev.SyncCache(); err != nil { |
|||
return SCSIResult{ |
|||
Status: SCSIStatusCheckCond, |
|||
SenseKey: SenseHardwareError, |
|||
SenseASC: 0x00, |
|||
SenseASCQ: 0x00, |
|||
} |
|||
} |
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
func (h *SCSIHandler) unmap(cdb [16]byte, dataOut []byte) SCSIResult { |
|||
if len(dataOut) < 8 { |
|||
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
|||
} |
|||
|
|||
// UNMAP parameter list header (8 bytes)
|
|||
// descLen := binary.BigEndian.Uint16(dataOut[0:2]) // data length (unused)
|
|||
blockDescLen := binary.BigEndian.Uint16(dataOut[2:4]) |
|||
|
|||
if int(blockDescLen)+8 > len(dataOut) { |
|||
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
|||
} |
|||
|
|||
// Each UNMAP block descriptor is 16 bytes
|
|||
descData := dataOut[8 : 8+blockDescLen] |
|||
for len(descData) >= 16 { |
|||
lba := binary.BigEndian.Uint64(descData[0:8]) |
|||
numBlocks := binary.BigEndian.Uint32(descData[8:12]) |
|||
// descData[12:16] reserved
|
|||
|
|||
if numBlocks > 0 { |
|||
blockSize := h.dev.BlockSize() |
|||
if err := h.dev.Trim(lba, numBlocks*blockSize); err != nil { |
|||
return SCSIResult{ |
|||
Status: SCSIStatusCheckCond, |
|||
SenseKey: SenseMediumError, |
|||
SenseASC: 0x0C, |
|||
SenseASCQ: 0x00, |
|||
} |
|||
} |
|||
} |
|||
descData = descData[16:] |
|||
} |
|||
|
|||
return SCSIResult{Status: SCSIStatusGood} |
|||
} |
|||
|
|||
// BuildSenseData constructs an 18-byte fixed-format sense data buffer
// (SPC-5) for the given sense key, ASC, and ASCQ.
func BuildSenseData(key, asc, ascq uint8) []byte {
	sense := make([]byte, 18)
	sense[0] = 0x70       // response code: current errors, fixed format
	sense[2] = key & 0x0f // sense key occupies the low nibble of byte 2
	sense[7] = 10         // additional sense length (bytes 8..17)
	sense[12] = asc       // additional sense code
	sense[13] = ascq      // additional sense code qualifier
	return sense
}
|||
|
|||
// illegalRequest builds a CHECK CONDITION result with ILLEGAL REQUEST sense
// and the given ASC/ASCQ pair.
func illegalRequest(asc, ascq uint8) SCSIResult {
	return SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  SenseIllegalRequest,
		SenseASC:  asc,
		SenseASCQ: ascq,
	}
}
} |
|||
|
|||
// padRight returns s space-padded to exactly n bytes; strings longer than n
// are truncated. Used for fixed-width INQUIRY identity fields.
func padRight(s string, n int) string {
	if len(s) >= n {
		return s[:n]
	}
	out := make([]byte, 0, n)
	out = append(out, s...)
	for len(out) < n {
		out = append(out, ' ')
	}
	return string(out)
}
|||
@ -0,0 +1,692 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
"testing" |
|||
) |
|||
|
|||
// mockBlockDevice implements BlockDevice for testing. Blocks are stored
// sparsely in a map keyed by LBA; unwritten blocks read back as zeros.
// The *Err fields, when set, force the corresponding method to fail.
type mockBlockDevice struct {
	blockSize  uint32
	volumeSize uint64
	healthy    bool
	blocks     map[uint64][]byte // LBA -> data
	syncErr    error
	readErr    error
	writeErr   error
	trimErr    error
}

// newMockDevice builds a healthy 4KiB-block mock device of the given size.
func newMockDevice(volumeSize uint64) *mockBlockDevice {
	return &mockBlockDevice{
		blockSize:  4096,
		volumeSize: volumeSize,
		healthy:    true,
		blocks:     make(map[uint64][]byte),
	}
}

// ReadAt reads length bytes starting at lba, substituting zeros for blocks
// never written.
func (m *mockBlockDevice) ReadAt(lba uint64, length uint32) ([]byte, error) {
	if m.readErr != nil {
		return nil, m.readErr
	}
	blockCount := length / m.blockSize
	result := make([]byte, length)
	for i := uint32(0); i < blockCount; i++ {
		if data, ok := m.blocks[lba+uint64(i)]; ok {
			copy(result[i*m.blockSize:], data)
		}
		// Unwritten blocks return zeros (already zeroed)
	}
	return result, nil
}

// WriteAt stores data block-by-block starting at lba; each block is copied
// so later mutations of data do not alias stored contents.
func (m *mockBlockDevice) WriteAt(lba uint64, data []byte) error {
	if m.writeErr != nil {
		return m.writeErr
	}
	blockCount := uint32(len(data)) / m.blockSize
	for i := uint32(0); i < blockCount; i++ {
		block := make([]byte, m.blockSize)
		copy(block, data[i*m.blockSize:])
		m.blocks[lba+uint64(i)] = block
	}
	return nil
}

// Trim deletes length bytes' worth of blocks starting at lba, so they read
// back as zeros afterward.
func (m *mockBlockDevice) Trim(lba uint64, length uint32) error {
	if m.trimErr != nil {
		return m.trimErr
	}
	blockCount := length / m.blockSize
	for i := uint32(0); i < blockCount; i++ {
		delete(m.blocks, lba+uint64(i))
	}
	return nil
}

func (m *mockBlockDevice) SyncCache() error    { return m.syncErr }
func (m *mockBlockDevice) BlockSize() uint32   { return m.blockSize }
func (m *mockBlockDevice) VolumeSize() uint64  { return m.volumeSize }
func (m *mockBlockDevice) IsHealthy() bool     { return m.healthy }
|||
|
|||
// TestSCSI is the table-driven entry point for all SCSI handler subtests.
func TestSCSI(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"test_unit_ready_good", testTestUnitReadyGood},
		{"test_unit_ready_not_ready", testTestUnitReadyNotReady},
		{"inquiry_standard", testInquiryStandard},
		{"inquiry_vpd_supported_pages", testInquiryVPDSupportedPages},
		{"inquiry_vpd_serial", testInquiryVPDSerial},
		{"inquiry_vpd_device_id", testInquiryVPDDeviceID},
		{"inquiry_vpd_unknown_page", testInquiryVPDUnknownPage},
		{"inquiry_alloc_length", testInquiryAllocLength},
		{"read_capacity_10", testReadCapacity10},
		{"read_capacity_10_large", testReadCapacity10Large},
		{"read_capacity_16", testReadCapacity16},
		{"read_capacity_16_lbpme", testReadCapacity16LBPME},
		{"mode_sense_6", testModeSense6},
		{"report_luns", testReportLuns},
		{"unknown_opcode", testUnknownOpcode},
		{"read_10", testRead10},
		{"read_16", testRead16},
		{"write_10", testWrite10},
		{"write_16", testWrite16},
		{"read_write_roundtrip", testReadWriteRoundtrip},
		{"write_oob", testWriteOOB},
		{"read_oob", testReadOOB},
		{"zero_length_transfer", testZeroLengthTransfer},
		{"sync_cache", testSyncCache},
		{"sync_cache_error", testSyncCacheError},
		{"unmap_single", testUnmapSingle},
		{"unmap_multiple_descriptors", testUnmapMultipleDescriptors},
		{"unmap_short_param", testUnmapShortParam},
		{"build_sense_data", testBuildSenseData},
		{"read_error", testReadError},
		{"write_error", testWriteError},
		{"read_capacity_16_invalid_sa", testReadCapacity16InvalidSA},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
func testTestUnitReadyGood(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiTestUnitReady |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
} |
|||
|
|||
func testTestUnitReadyNotReady(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
dev.healthy = false |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiTestUnitReady |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
if r.SenseKey != SenseNotReady { |
|||
t.Fatalf("sense key: %d", r.SenseKey) |
|||
} |
|||
} |
|||
|
|||
func testInquiryStandard(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
binary.BigEndian.PutUint16(cdb[3:5], 96) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
if len(r.Data) != 96 { |
|||
t.Fatalf("data length: %d", len(r.Data)) |
|||
} |
|||
// Check peripheral device type
|
|||
if r.Data[0] != 0x00 { |
|||
t.Fatal("not SBC device type") |
|||
} |
|||
// Check vendor
|
|||
vendor := string(r.Data[8:16]) |
|||
if vendor != "SeaweedF" { |
|||
t.Fatalf("vendor: %q", vendor) |
|||
} |
|||
// Check CmdQue
|
|||
if r.Data[7]&0x02 == 0 { |
|||
t.Fatal("CmdQue not set") |
|||
} |
|||
} |
|||
|
|||
func testInquiryVPDSupportedPages(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
cdb[1] = 0x01 // EVPD
|
|||
cdb[2] = 0x00 // Supported pages
|
|||
binary.BigEndian.PutUint16(cdb[3:5], 255) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
if r.Data[1] != 0x00 { |
|||
t.Fatal("wrong page code") |
|||
} |
|||
// Should list pages 0x00, 0x80, 0x83
|
|||
if len(r.Data) < 7 || r.Data[4] != 0x00 || r.Data[5] != 0x80 || r.Data[6] != 0x83 { |
|||
t.Fatalf("supported pages: %v", r.Data) |
|||
} |
|||
} |
|||
|
|||
func testInquiryVPDSerial(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
cdb[1] = 0x01 |
|||
cdb[2] = 0x80 |
|||
binary.BigEndian.PutUint16(cdb[3:5], 255) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
if r.Data[1] != 0x80 { |
|||
t.Fatal("wrong page code") |
|||
} |
|||
} |
|||
|
|||
func testInquiryVPDDeviceID(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
cdb[1] = 0x01 |
|||
cdb[2] = 0x83 |
|||
binary.BigEndian.PutUint16(cdb[3:5], 255) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatalf("status: %d", r.Status) |
|||
} |
|||
if r.Data[1] != 0x83 { |
|||
t.Fatal("wrong page code") |
|||
} |
|||
} |
|||
|
|||
func testInquiryVPDUnknownPage(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
cdb[1] = 0x01 |
|||
cdb[2] = 0xFF // unknown page
|
|||
binary.BigEndian.PutUint16(cdb[3:5], 255) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("expected CHECK_CONDITION") |
|||
} |
|||
if r.SenseKey != SenseIllegalRequest { |
|||
t.Fatal("expected ILLEGAL_REQUEST") |
|||
} |
|||
} |
|||
|
|||
func testInquiryAllocLength(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiInquiry |
|||
binary.BigEndian.PutUint16(cdb[3:5], 10) // small alloc
|
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("should succeed") |
|||
} |
|||
if len(r.Data) != 10 { |
|||
t.Fatalf("truncation: expected 10, got %d", len(r.Data)) |
|||
} |
|||
} |
|||
|
|||
func testReadCapacity10(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) // 100 blocks
|
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiReadCapacity10 |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("status not good") |
|||
} |
|||
if len(r.Data) != 8 { |
|||
t.Fatalf("data length: %d", len(r.Data)) |
|||
} |
|||
lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) |
|||
blockSize := binary.BigEndian.Uint32(r.Data[4:8]) |
|||
if lastLBA != 99 { |
|||
t.Fatalf("last LBA: %d, expected 99", lastLBA) |
|||
} |
|||
if blockSize != 4096 { |
|||
t.Fatalf("block size: %d", blockSize) |
|||
} |
|||
} |
|||
|
|||
func testReadCapacity10Large(t *testing.T) { |
|||
// Volume with >2^32 blocks should return 0xFFFFFFFF
|
|||
dev := newMockDevice(uint64(0x100000001) * 4096) // 2^32+1 blocks
|
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiReadCapacity10 |
|||
r := h.HandleCommand(cdb, nil) |
|||
lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) |
|||
if lastLBA != 0xFFFFFFFF { |
|||
t.Fatalf("should return 0xFFFFFFFF for >2TB, got %d", lastLBA) |
|||
} |
|||
} |
|||
|
|||
func testReadCapacity16(t *testing.T) { |
|||
dev := newMockDevice(3 * 1024 * 1024 * 1024 * 1024) // 3 TB
|
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiReadCapacity16 |
|||
cdb[1] = ScsiSAReadCapacity16 |
|||
binary.BigEndian.PutUint32(cdb[10:14], 32) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("status not good") |
|||
} |
|||
lastLBA := binary.BigEndian.Uint64(r.Data[0:8]) |
|||
expectedBlocks := uint64(3*1024*1024*1024*1024) / 4096 |
|||
if lastLBA != expectedBlocks-1 { |
|||
t.Fatalf("last LBA: %d, expected %d", lastLBA, expectedBlocks-1) |
|||
} |
|||
} |
|||
|
|||
func testReadCapacity16LBPME(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiReadCapacity16 |
|||
cdb[1] = ScsiSAReadCapacity16 |
|||
binary.BigEndian.PutUint32(cdb[10:14], 32) |
|||
r := h.HandleCommand(cdb, nil) |
|||
// LBPME bit should be set (byte 14, bit 7)
|
|||
if r.Data[14]&0x80 == 0 { |
|||
t.Fatal("LBPME bit not set") |
|||
} |
|||
} |
|||
|
|||
// testModeSense6 verifies MODE SENSE (6) returns the minimal 4-byte
// header with the write-protect bit clear.
func testModeSense6(t *testing.T) {
	dev := newMockDevice(1024 * 1024)
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiModeSense6
	cdb[4] = 255 // allocation length
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusGood {
		t.Fatal("status not good")
	}
	if len(r.Data) != 4 {
		t.Fatalf("mode sense data: %d bytes", len(r.Data))
	}
	// No write protect
	if r.Data[2]&0x80 != 0 {
		t.Fatal("write protect set")
	}
}
|||
|
|||
// testReportLuns verifies REPORT LUNS advertises exactly one LUN
// (8-byte list length in the response header).
func testReportLuns(t *testing.T) {
	dev := newMockDevice(1024 * 1024)
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiReportLuns
	// Allocation length (CDB bytes 6-9).
	binary.BigEndian.PutUint32(cdb[6:10], 256)
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusGood {
		t.Fatal("status not good")
	}
	lunListLen := binary.BigEndian.Uint32(r.Data[0:4])
	if lunListLen != 8 {
		t.Fatalf("LUN list length: %d (expected 8 for 1 LUN)", lunListLen)
	}
}
|||
|
|||
func testUnknownOpcode(t *testing.T) { |
|||
dev := newMockDevice(1024 * 1024) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = 0xFF |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("expected CHECK_CONDITION") |
|||
} |
|||
if r.SenseKey != SenseIllegalRequest { |
|||
t.Fatal("expected ILLEGAL_REQUEST") |
|||
} |
|||
} |
|||
|
|||
// testRead10 verifies READ (10) returns the block previously stored in
// the mock device.
func testRead10(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	// Write some data first
	data := make([]byte, 4096)
	for i := range data {
		data[i] = 0xAB
	}
	dev.blocks[5] = data

	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 5) // LBA=5
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusGood {
		t.Fatal("read failed")
	}
	if len(r.Data) != 4096 {
		t.Fatalf("data length: %d", len(r.Data))
	}
	if r.Data[0] != 0xAB {
		t.Fatal("data mismatch")
	}
}
|||
|
|||
// testRead16 verifies READ (16) with its 64-bit LBA field reads back a
// stored block.
func testRead16(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	data := make([]byte, 4096)
	data[0] = 0xCD
	dev.blocks[10] = data

	var cdb [16]byte
	cdb[0] = ScsiRead16
	// READ(16): 64-bit LBA at bytes 2-9, 32-bit transfer length at 10-13.
	binary.BigEndian.PutUint64(cdb[2:10], 10)
	binary.BigEndian.PutUint32(cdb[10:14], 1)
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusGood {
		t.Fatal("read16 failed")
	}
	if r.Data[0] != 0xCD {
		t.Fatal("data mismatch")
	}
}
|||
|
|||
// testWrite10 verifies WRITE (10) lands the data-out payload in the
// target block of the mock device.
func testWrite10(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	dataOut := make([]byte, 4096)
	dataOut[0] = 0xEF

	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 7) // LBA=7
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	r := h.HandleCommand(cdb, dataOut)
	if r.Status != SCSIStatusGood {
		t.Fatal("write failed")
	}
	if dev.blocks[7][0] != 0xEF {
		t.Fatal("data not written")
	}
}
|||
|
|||
// testWrite16 verifies a two-block WRITE (16) splits the payload across
// consecutive LBAs.
func testWrite16(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	dataOut := make([]byte, 8192)
	dataOut[0] = 0x11    // first byte of block 50
	dataOut[4096] = 0x22 // first byte of block 51

	var cdb [16]byte
	cdb[0] = ScsiWrite16
	binary.BigEndian.PutUint64(cdb[2:10], 50)
	binary.BigEndian.PutUint32(cdb[10:14], 2)
	r := h.HandleCommand(cdb, dataOut)
	if r.Status != SCSIStatusGood {
		t.Fatal("write16 failed")
	}
	if dev.blocks[50][0] != 0x11 {
		t.Fatal("block 50 wrong")
	}
	if dev.blocks[51][0] != 0x22 {
		t.Fatal("block 51 wrong")
	}
}
|||
|
|||
func testReadWriteRoundtrip(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
|
|||
// Write
|
|||
dataOut := make([]byte, 4096) |
|||
for i := range dataOut { |
|||
dataOut[i] = byte(i % 256) |
|||
} |
|||
var wcdb [16]byte |
|||
wcdb[0] = ScsiWrite10 |
|||
binary.BigEndian.PutUint32(wcdb[2:6], 0) |
|||
binary.BigEndian.PutUint16(wcdb[7:9], 1) |
|||
h.HandleCommand(wcdb, dataOut) |
|||
|
|||
// Read back
|
|||
var rcdb [16]byte |
|||
rcdb[0] = ScsiRead10 |
|||
binary.BigEndian.PutUint32(rcdb[2:6], 0) |
|||
binary.BigEndian.PutUint16(rcdb[7:9], 1) |
|||
r := h.HandleCommand(rcdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("read failed") |
|||
} |
|||
for i := 0; i < 4096; i++ { |
|||
if r.Data[i] != byte(i%256) { |
|||
t.Fatalf("byte %d: got %d, want %d", i, r.Data[i], i%256) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testWriteOOB(t *testing.T) { |
|||
dev := newMockDevice(10 * 4096) // 10 blocks
|
|||
h := NewSCSIHandler(dev) |
|||
|
|||
var cdb [16]byte |
|||
cdb[0] = ScsiWrite10 |
|||
binary.BigEndian.PutUint32(cdb[2:6], 9) |
|||
binary.BigEndian.PutUint16(cdb[7:9], 2) // LBA 9 + 2 blocks > 10
|
|||
r := h.HandleCommand(cdb, make([]byte, 8192)) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("should fail for OOB") |
|||
} |
|||
} |
|||
|
|||
func testReadOOB(t *testing.T) { |
|||
dev := newMockDevice(10 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
|
|||
var cdb [16]byte |
|||
cdb[0] = ScsiRead10 |
|||
binary.BigEndian.PutUint32(cdb[2:6], 10) // LBA 10 == total blocks
|
|||
binary.BigEndian.PutUint16(cdb[7:9], 1) |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("should fail for OOB") |
|||
} |
|||
} |
|||
|
|||
func testZeroLengthTransfer(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
|
|||
var cdb [16]byte |
|||
cdb[0] = ScsiRead10 |
|||
binary.BigEndian.PutUint32(cdb[2:6], 0) |
|||
binary.BigEndian.PutUint16(cdb[7:9], 0) // 0 blocks
|
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("zero-length read should succeed") |
|||
} |
|||
} |
|||
|
|||
func testSyncCache(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiSyncCache10 |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusGood { |
|||
t.Fatal("sync cache failed") |
|||
} |
|||
} |
|||
|
|||
func testSyncCacheError(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
dev.syncErr = errors.New("disk error") |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiSyncCache10 |
|||
r := h.HandleCommand(cdb, nil) |
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("should fail") |
|||
} |
|||
} |
|||
|
|||
// testUnmapSingle verifies UNMAP with one block descriptor trims the
// targeted block out of the mock device's sparse map.
func testUnmapSingle(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	// Write data at LBA 5
	dev.blocks[5] = make([]byte, 4096)
	dev.blocks[5][0] = 0xFF

	// UNMAP parameter list
	unmapData := make([]byte, 24) // 8 header + 16 descriptor
	binary.BigEndian.PutUint16(unmapData[0:2], 22)  // data length
	binary.BigEndian.PutUint16(unmapData[2:4], 16)  // block desc length
	binary.BigEndian.PutUint64(unmapData[8:16], 5)  // LBA
	binary.BigEndian.PutUint32(unmapData[16:20], 1) // num blocks

	var cdb [16]byte
	cdb[0] = ScsiUnmap
	r := h.HandleCommand(cdb, unmapData)
	if r.Status != SCSIStatusGood {
		t.Fatal("unmap failed")
	}
	// The sparse map drops trimmed blocks entirely.
	if _, ok := dev.blocks[5]; ok {
		t.Fatal("block 5 should be trimmed")
	}
}
|||
|
|||
// testUnmapMultipleDescriptors verifies UNMAP processes every block
// descriptor in the parameter list, not just the first.
func testUnmapMultipleDescriptors(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)

	dev.blocks[3] = make([]byte, 4096)
	dev.blocks[7] = make([]byte, 4096)

	// 2 descriptors
	unmapData := make([]byte, 40) // 8 header + 2*16 descriptors
	binary.BigEndian.PutUint16(unmapData[0:2], 38) // data length
	binary.BigEndian.PutUint16(unmapData[2:4], 32) // block desc length
	// Descriptor 1: LBA=3, count=1
	binary.BigEndian.PutUint64(unmapData[8:16], 3)
	binary.BigEndian.PutUint32(unmapData[16:20], 1)
	// Descriptor 2: LBA=7, count=1
	binary.BigEndian.PutUint64(unmapData[24:32], 7)
	binary.BigEndian.PutUint32(unmapData[32:36], 1)

	var cdb [16]byte
	cdb[0] = ScsiUnmap
	r := h.HandleCommand(cdb, unmapData)
	if r.Status != SCSIStatusGood {
		t.Fatal("unmap failed")
	}
	if _, ok := dev.blocks[3]; ok {
		t.Fatal("block 3 should be trimmed")
	}
	if _, ok := dev.blocks[7]; ok {
		t.Fatal("block 7 should be trimmed")
	}
}
|||
|
|||
func testUnmapShortParam(t *testing.T) { |
|||
dev := newMockDevice(100 * 4096) |
|||
h := NewSCSIHandler(dev) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiUnmap |
|||
r := h.HandleCommand(cdb, []byte{1, 2, 3}) // too short
|
|||
if r.Status != SCSIStatusCheckCond { |
|||
t.Fatal("should fail for short unmap params") |
|||
} |
|||
} |
|||
|
|||
// testBuildSenseData verifies the fixed-format sense layout: 18 bytes,
// response code 0x70, sense key at byte 2, ASC at byte 12.
func testBuildSenseData(t *testing.T) {
	data := BuildSenseData(SenseIllegalRequest, ASCInvalidOpcode, ASCQLuk)
	if len(data) != 18 {
		t.Fatalf("length: %d", len(data))
	}
	if data[0] != 0x70 { // current-error, fixed format
		t.Fatal("response code wrong")
	}
	if data[2] != SenseIllegalRequest {
		t.Fatal("sense key wrong")
	}
	if data[12] != ASCInvalidOpcode {
		t.Fatal("ASC wrong")
	}
}
|||
|
|||
// testReadError verifies a device read failure maps to CHECK CONDITION
// with MEDIUM ERROR sense.
func testReadError(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	dev.readErr = errors.New("io error")
	h := NewSCSIHandler(dev)

	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 0)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail")
	}
	if r.SenseKey != SenseMediumError {
		t.Fatal("should be MEDIUM_ERROR")
	}
}
|||
|
|||
// testWriteError verifies a device write failure surfaces as
// CHECK CONDITION.
func testWriteError(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	dev.writeErr = errors.New("io error")
	h := NewSCSIHandler(dev)

	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 0)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	r := h.HandleCommand(cdb, make([]byte, 4096))
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail")
	}
}
|||
|
|||
// testReadCapacity16InvalidSA verifies the 0x9E opcode rejects a
// service action other than READ CAPACITY (16).
func testReadCapacity16InvalidSA(t *testing.T) {
	dev := newMockDevice(100 * 4096)
	h := NewSCSIHandler(dev)
	var cdb [16]byte
	cdb[0] = ScsiReadCapacity16
	cdb[1] = 0x05 // wrong service action
	r := h.HandleCommand(cdb, nil)
	if r.Status != SCSIStatusCheckCond {
		t.Fatal("should fail for wrong SA")
	}
}
|||
@ -0,0 +1,421 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"sync" |
|||
"sync/atomic" |
|||
) |
|||
|
|||
// Sentinel errors reported by session operations; compare with errors.Is.
var (
	ErrSessionClosed    = errors.New("iscsi: session closed")
	ErrCmdSNOutOfWindow = errors.New("iscsi: CmdSN out of window")
)
|||
|
|||
// SessionState tracks the lifecycle of an iSCSI session.
type SessionState int

// Session lifecycle states, in the order a session normally moves
// through them.
const (
	SessionLogin    SessionState = iota // login phase
	SessionLoggedIn                     // full feature phase
	SessionLogout                       // logout requested
	SessionClosed                       // terminated
)
|||
|
|||
// Session manages a single iSCSI session (one initiator connection).
// All PDU processing funnels through HandleConnection on one goroutine;
// mu serializes statSN updates and writes to conn against Close.
type Session struct {
	mu sync.Mutex

	state    SessionState
	conn     net.Conn
	scsi     *SCSIHandler // bound after login via devices; may remain nil
	config   TargetConfig
	resolver TargetResolver
	devices  DeviceLookup

	// Sequence numbers
	expCmdSN atomic.Uint32 // expected CmdSN from initiator
	maxCmdSN atomic.Uint32 // max CmdSN we allow
	statSN   uint32        // target status sequence number (guarded by mu)

	// Login state
	negotiator *LoginNegotiator
	loginDone  bool

	// Data sequencing
	dataInWriter *DataInWriter // created once login negotiation completes

	// Shutdown
	closed   atomic.Bool
	closeErr error // NOTE(review): never assigned in this file — confirm intended use

	// Logging
	logger *log.Logger
}
|||
|
|||
// NewSession creates a new iSCSI session on the given connection.
|
|||
func NewSession(conn net.Conn, config TargetConfig, resolver TargetResolver, devices DeviceLookup, logger *log.Logger) *Session { |
|||
if logger == nil { |
|||
logger = log.Default() |
|||
} |
|||
s := &Session{ |
|||
state: SessionLogin, |
|||
conn: conn, |
|||
config: config, |
|||
resolver: resolver, |
|||
devices: devices, |
|||
negotiator: NewLoginNegotiator(config), |
|||
logger: logger, |
|||
} |
|||
s.expCmdSN.Store(1) |
|||
s.maxCmdSN.Store(32) // window of 32 commands
|
|||
return s |
|||
} |
|||
|
|||
// HandleConnection processes PDUs until the connection is closed or an error occurs.
// Errors caused by our own Close() or by the initiator hanging up (EOF /
// closed connection) are treated as clean shutdown and reported as nil.
func (s *Session) HandleConnection() error {
	defer s.close()

	for !s.closed.Load() {
		pdu, err := ReadPDU(s.conn)
		if err != nil {
			if s.closed.Load() {
				// Read failed because we closed the connection ourselves.
				return nil
			}
			if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
				// Initiator disconnected; normal termination.
				return nil
			}
			return fmt.Errorf("read PDU: %w", err)
		}

		if err := s.dispatch(pdu); err != nil {
			if s.closed.Load() {
				return nil
			}
			return fmt.Errorf("dispatch %s: %w", OpcodeName(pdu.Opcode()), err)
		}
	}
	return nil
}
|||
|
|||
// Close terminates the session from outside the read loop. Closing the
// connection unblocks the pending ReadPDU, and the closed flag makes
// HandleConnection treat that failure as a clean exit.
func (s *Session) Close() error {
	s.closed.Store(true)
	return s.conn.Close()
}
|||
|
|||
// close is the internal teardown run when HandleConnection returns: it
// marks the session closed and releases the connection. Safe to call
// after Close(); double-closing the conn only yields an ignored error.
func (s *Session) close() {
	s.closed.Store(true)
	s.mu.Lock()
	s.state = SessionClosed
	s.mu.Unlock()
	s.conn.Close()
}
|||
|
|||
// dispatch routes one PDU to its opcode-specific handler. Unknown
// opcodes are answered with a Reject PDU instead of dropping the
// connection.
func (s *Session) dispatch(pdu *PDU) error {
	op := pdu.Opcode()

	switch op {
	case OpLoginReq:
		return s.handleLogin(pdu)
	case OpTextReq:
		return s.handleText(pdu)
	case OpSCSICmd:
		return s.handleSCSICmd(pdu)
	case OpSCSIDataOut:
		// Handled inline during write command processing
		return nil
	case OpNOPOut:
		return s.handleNOPOut(pdu)
	case OpLogoutReq:
		return s.handleLogout(pdu)
	case OpSCSITaskMgmt:
		return s.handleTaskMgmt(pdu)
	default:
		s.logger.Printf("unhandled opcode: %s", OpcodeName(op))
		return s.sendReject(pdu, 0x04) // command not supported
	}
}
|||
|
|||
// handleLogin feeds one Login Request through the negotiator and writes
// the response. Once negotiation completes, the session enters the full
// feature phase, a Data-In writer sized to the negotiated segment length
// is created, and the session is bound to the device for the logged-in
// target (if a DeviceLookup was provided).
func (s *Session) handleLogin(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	resp := s.negotiator.HandleLoginPDU(pdu, s.resolver)

	// Set sequence numbers
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}

	if s.negotiator.Done() {
		s.loginDone = true
		s.state = SessionLoggedIn
		result := s.negotiator.Result()
		// Data-In PDUs are chunked to the negotiated MaxRecvDataSegmentLength.
		s.dataInWriter = NewDataInWriter(uint32(result.MaxRecvDataSegLen))

		// Bind SCSI handler to the device for the target the initiator logged into
		if s.devices != nil && result.TargetName != "" {
			dev := s.devices.LookupDevice(result.TargetName)
			s.scsi = NewSCSIHandler(dev)
		}

		s.logger.Printf("login complete: initiator=%s target=%s session=%s",
			result.InitiatorName, result.TargetName, result.SessionType)
	}

	return nil
}
|||
|
|||
func (s *Session) handleText(pdu *PDU) error { |
|||
s.mu.Lock() |
|||
defer s.mu.Unlock() |
|||
|
|||
// Gather discovery targets from resolver if it supports listing
|
|||
var targets []DiscoveryTarget |
|||
if lister, ok := s.resolver.(TargetLister); ok { |
|||
targets = lister.ListTargets() |
|||
} |
|||
|
|||
resp := HandleTextRequest(pdu, targets) |
|||
resp.SetStatSN(s.statSN) |
|||
s.statSN++ |
|||
resp.SetExpCmdSN(s.expCmdSN.Load()) |
|||
resp.SetMaxCmdSN(s.maxCmdSN.Load()) |
|||
|
|||
return WritePDU(s.conn, resp) |
|||
} |
|||
|
|||
func (s *Session) handleSCSICmd(pdu *PDU) error { |
|||
if !s.loginDone { |
|||
return s.sendReject(pdu, 0x0b) // protocol error
|
|||
} |
|||
|
|||
cdb := pdu.CDB() |
|||
itt := pdu.InitiatorTaskTag() |
|||
flags := pdu.OpSpecific1() |
|||
|
|||
// Advance CmdSN
|
|||
if !pdu.Immediate() { |
|||
s.advanceCmdSN() |
|||
} |
|||
|
|||
isWrite := flags&FlagW != 0 |
|||
isRead := flags&FlagR != 0 |
|||
expectedLen := pdu.ExpectedDataTransferLength() |
|||
|
|||
// Handle write commands — collect data
|
|||
var dataOut []byte |
|||
if isWrite && expectedLen > 0 { |
|||
collector := NewDataOutCollector(expectedLen) |
|||
|
|||
// Immediate data
|
|||
if len(pdu.DataSegment) > 0 { |
|||
if err := collector.AddImmediateData(pdu.DataSegment); err != nil { |
|||
return s.sendCheckCondition(itt, SenseIllegalRequest, ASCInvalidFieldInCDB, ASCQLuk) |
|||
} |
|||
} |
|||
|
|||
// If more data needed, send R2T and collect Data-Out PDUs
|
|||
if !collector.Done() { |
|||
if err := s.collectDataOut(collector, itt); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
dataOut = collector.Data() |
|||
} |
|||
|
|||
// Execute SCSI command
|
|||
result := s.scsi.HandleCommand(cdb, dataOut) |
|||
|
|||
// Send response
|
|||
s.mu.Lock() |
|||
expCmdSN := s.expCmdSN.Load() |
|||
maxCmdSN := s.maxCmdSN.Load() |
|||
s.mu.Unlock() |
|||
|
|||
if isRead && result.Status == SCSIStatusGood && len(result.Data) > 0 { |
|||
// Send Data-In PDUs
|
|||
s.mu.Lock() |
|||
_, err := s.dataInWriter.WriteDataIn(s.conn, result.Data, itt, expCmdSN, maxCmdSN, &s.statSN) |
|||
s.mu.Unlock() |
|||
return err |
|||
} |
|||
|
|||
// Send SCSI Response
|
|||
s.mu.Lock() |
|||
err := SendSCSIResponse(s.conn, result, itt, &s.statSN, expCmdSN, maxCmdSN) |
|||
s.mu.Unlock() |
|||
return err |
|||
} |
|||
|
|||
// collectDataOut runs the R2T loop for a write command: it repeatedly
// sends an R2T asking for the remaining bytes, then consumes Data-Out
// PDUs until the F (final) bit ends the sequence, stopping once the
// collector holds the full expected transfer.
// NOTE(review): each R2T requests collector.Remaining() in one burst —
// confirm this respects the negotiated MaxBurstLength.
func (s *Session) collectDataOut(collector *DataOutCollector, itt uint32) error {
	var r2tSN uint32
	ttt := itt // use ITT as TTT for simplicity

	for !collector.Done() {
		// Send R2T
		s.mu.Lock()
		r2t := BuildR2T(itt, ttt, r2tSN, s.totalReceived(collector), collector.Remaining(),
			s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load())
		s.mu.Unlock()

		if err := WritePDU(s.conn, r2t); err != nil {
			return err
		}
		r2tSN++

		// Read Data-Out PDUs until F-bit
		for {
			doPDU, err := ReadPDU(s.conn)
			if err != nil {
				return err
			}
			if doPDU.Opcode() != OpSCSIDataOut {
				// Any interleaved PDU is a protocol violation here.
				return fmt.Errorf("expected Data-Out, got %s", OpcodeName(doPDU.Opcode()))
			}
			if err := collector.AddDataOut(doPDU); err != nil {
				return err
			}
			if doPDU.OpSpecific1()&FlagF != 0 {
				break
			}
		}
	}
	return nil
}
|||
|
|||
func (s *Session) totalReceived(c *DataOutCollector) uint32 { |
|||
return c.expectedLen - c.Remaining() |
|||
} |
|||
|
|||
func (s *Session) handleNOPOut(pdu *PDU) error { |
|||
resp := &PDU{} |
|||
resp.SetOpcode(OpNOPIn) |
|||
resp.SetOpSpecific1(FlagF) |
|||
resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag()) |
|||
resp.SetTargetTransferTag(0xFFFFFFFF) |
|||
|
|||
s.mu.Lock() |
|||
resp.SetStatSN(s.statSN) |
|||
s.statSN++ |
|||
resp.SetExpCmdSN(s.expCmdSN.Load()) |
|||
resp.SetMaxCmdSN(s.maxCmdSN.Load()) |
|||
s.mu.Unlock() |
|||
|
|||
// Echo back data if present
|
|||
if len(pdu.DataSegment) > 0 { |
|||
resp.DataSegment = pdu.DataSegment |
|||
} |
|||
|
|||
return WritePDU(s.conn, resp) |
|||
} |
|||
|
|||
// handleLogout acknowledges a Logout Request with a success response and
// then closes the connection, which makes HandleConnection exit cleanly.
func (s *Session) handleLogout(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.state = SessionLogout

	resp := &PDU{}
	resp.SetOpcode(OpLogoutResp)
	resp.SetOpSpecific1(FlagF)
	resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	resp.BHS[2] = 0x00 // response: connection/session closed successfully
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}

	// Signal HandleConnection to exit
	s.closed.Store(true)
	s.conn.Close()
	return nil
}
|||
|
|||
// handleTaskMgmt answers any Task Management Function Request with
// "function complete". It does not actually abort or reset anything —
// commands are processed synchronously, so there is never an in-flight
// task to act on.
func (s *Session) handleTaskMgmt(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Simplified: always respond with "function complete"
	resp := &PDU{}
	resp.SetOpcode(OpSCSITaskResp)
	resp.SetOpSpecific1(FlagF)
	resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	resp.BHS[2] = 0x00 // function complete
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	return WritePDU(s.conn, resp)
}
|||
|
|||
// advanceCmdSN slides the command window forward by one after accepting
// a non-immediate command, keeping the window size constant.
func (s *Session) advanceCmdSN() {
	s.expCmdSN.Add(1)
	s.maxCmdSN.Add(1)
}
|||
|
|||
// sendReject answers an unusable PDU with a Reject carrying the given
// reason code; the rejected PDU's header is echoed in the data segment
// so the initiator can identify which PDU was refused.
func (s *Session) sendReject(origPDU *PDU, reason uint8) error {
	resp := &PDU{}
	resp.SetOpcode(OpReject)
	resp.SetOpSpecific1(FlagF)
	resp.BHS[2] = reason
	resp.SetInitiatorTaskTag(0xFFFFFFFF) // Reject carries the reserved ITT

	s.mu.Lock()
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())
	s.mu.Unlock()

	// Include the rejected BHS in the data segment
	resp.DataSegment = origPDU.BHS[:]

	return WritePDU(s.conn, resp)
}
|||
|
|||
// sendCheckCondition sends a SCSI Response with CHECK CONDITION status
// and the given sense key / ASC / ASCQ triple as autosense data.
func (s *Session) sendCheckCondition(itt uint32, senseKey, asc, ascq uint8) error {
	result := SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  senseKey,
		SenseASC:  asc,
		SenseASCQ: ascq,
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return SendSCSIResponse(s.conn, result, itt, &s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load())
}
|||
|
|||
// TargetLister is an optional interface that TargetResolver can implement
// to support SendTargets discovery. Resolvers that do not implement it
// simply produce empty discovery responses.
type TargetLister interface {
	ListTargets() []DiscoveryTarget
}
|||
|
|||
// DeviceLookup resolves a target IQN to a BlockDevice.
// Used after login to bind the session to the correct volume.
// NOTE(review): nothing here documents whether LookupDevice may return
// nil for an unknown IQN — callers should assume it can.
type DeviceLookup interface {
	LookupDevice(iqn string) BlockDevice
}
|||
|
|||
// State returns the current session state (safe for concurrent use).
func (s *Session) State() SessionState {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.state
}
|||
@ -0,0 +1,364 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
// testResolver implements TargetResolver, TargetLister, and DeviceLookup.
// It is passed to NewSession as both resolver and device lookup in tests.
type testResolver struct {
	targets []DiscoveryTarget // discoverable targets returned by ListTargets
	dev     BlockDevice       // device returned by LookupDevice; nil means nullDevice
}
|||
|
|||
func (r *testResolver) HasTarget(name string) bool { |
|||
for _, t := range r.targets { |
|||
if t.Name == name { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// ListTargets returns the configured discovery targets (TargetLister).
func (r *testResolver) ListTargets() []DiscoveryTarget {
	return r.targets
}
|||
|
|||
func (r *testResolver) LookupDevice(iqn string) BlockDevice { |
|||
if r.dev != nil { |
|||
return r.dev |
|||
} |
|||
return &nullDevice{} |
|||
} |
|||
|
|||
// newTestResolver builds a resolver with the default test target and no
// backing device (LookupDevice yields a nullDevice).
func newTestResolver() *testResolver {
	return newTestResolverWithDevice(nil)
}
|||
|
|||
// newTestResolverWithDevice builds a resolver advertising the default
// test target, backed by dev (nil means a nullDevice is served).
func newTestResolverWithDevice(dev BlockDevice) *testResolver {
	return &testResolver{
		targets: []DiscoveryTarget{
			{Name: testTargetName, Address: "127.0.0.1:3260,1"},
		},
		dev: dev,
	}
}
|||
|
|||
// TestSession runs the session-level scenarios as subtests, each with
// its own pipe-backed session from setupSession.
func TestSession(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"login_and_read", testLoginAndRead},
		{"login_and_write", testLoginAndWrite},
		{"nop_ping", testNOPPing},
		{"logout", testLogout},
		{"discovery_session", testDiscoverySession},
		{"task_mgmt", testTaskMgmt},
		{"reject_scsi_before_login", testRejectSCSIBeforeLogin},
		{"connection_close", testConnectionClose},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// sessionTestEnv bundles the pieces a session test needs: the client end
// of the pipe, the session under test, and the channel carrying the
// HandleConnection exit error.
type sessionTestEnv struct {
	clientConn net.Conn
	session    *Session
	done       chan error
}
|||
|
|||
// setupSession wires a Session to an in-memory net.Pipe, starts its PDU
// loop in a goroutine, and registers cleanup that tears both ends down.
func setupSession(t *testing.T) *sessionTestEnv {
	t.Helper()
	client, server := net.Pipe()

	dev := newMockDevice(1024 * 4096) // 1024 blocks
	config := DefaultTargetConfig()
	config.TargetName = testTargetName
	resolver := newTestResolverWithDevice(dev)
	// Silence session logging in tests.
	logger := log.New(io.Discard, "", 0)

	// The resolver doubles as both TargetResolver and DeviceLookup.
	sess := NewSession(server, config, resolver, resolver, logger)
	done := make(chan error, 1)
	go func() {
		done <- sess.HandleConnection()
	}()

	t.Cleanup(func() {
		sess.Close()
		client.Close()
	})

	return &sessionTestEnv{clientConn: client, session: sess, done: done}
}
|||
|
|||
// doLogin performs a single-PDU Normal-session login over conn and fails
// the test unless the target answers with a successful Login Response.
func doLogin(t *testing.T, conn net.Conn) {
	t.Helper()
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)

	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLoginResp {
		t.Fatalf("expected LoginResp, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
}
|||
|
|||
// testLoginAndRead verifies that after login, a READ (10) comes back as
// a single Data-In PDU with the S-bit set and a full block of data.
func testLoginAndRead(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send SCSI READ_10 for 1 block at LBA 0
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR) // Final + Read
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)
	cmd.SetExpStatSN(2)

	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 0) // LBA=0
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	cmd.SetCDB(cdb)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	// Expect Data-In with S-bit (single PDU, 4096 bytes of zeros)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.OpSpecific1()&FlagS == 0 {
		t.Fatal("S-bit not set")
	}
	if len(resp.DataSegment) != 4096 {
		t.Fatalf("data length: %d", len(resp.DataSegment))
	}
}
|||
|
|||
// testLoginAndWrite verifies that a WRITE (10) carrying its full payload
// as immediate data (no R2T round needed) completes with GOOD status.
func testLoginAndWrite(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// SCSI WRITE_10: 1 block at LBA 5 with immediate data
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagW) // Final + Write
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)

	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 5)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	cmd.SetCDB(cdb)

	// Include immediate data
	cmd.DataSegment = bytes.Repeat([]byte{0xAA}, 4096)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	// Expect SCSI Response (good)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIResp {
		t.Fatalf("expected SCSI Response, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.SCSIStatus() != SCSIStatusGood {
		t.Fatalf("status: %d", resp.SCSIStatus())
	}
}
|||
|
|||
// testNOPPing verifies a NOP-Out ping is answered by a NOP-In with the
// same ITT and the ping payload echoed back.
func testNOPPing(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send NOP-Out
	nop := &PDU{}
	nop.SetOpcode(OpNOPOut)
	nop.SetOpSpecific1(FlagF)
	nop.SetInitiatorTaskTag(0x9999)
	nop.SetImmediate(true)
	nop.DataSegment = []byte("ping")

	if err := WritePDU(env.clientConn, nop); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpNOPIn {
		t.Fatalf("expected NOP-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.InitiatorTaskTag() != 0x9999 {
		t.Fatal("ITT mismatch")
	}
	if string(resp.DataSegment) != "ping" {
		t.Fatalf("echo data: %q", resp.DataSegment)
	}
}
|||
|
|||
// testLogout verifies a Logout Request gets a Logout Response and that
// the target then closes the connection.
func testLogout(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send Logout
	logout := &PDU{}
	logout.SetOpcode(OpLogoutReq)
	logout.SetOpSpecific1(FlagF)
	logout.SetInitiatorTaskTag(0xAAAA)
	logout.SetCmdSN(2)

	if err := WritePDU(env.clientConn, logout); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLogoutResp {
		t.Fatalf("expected Logout Response, got %s", OpcodeName(resp.Opcode()))
	}

	// After logout, the connection should be closed by the target.
	// Verify by trying to read — should get EOF.
	_, err = ReadPDU(env.clientConn)
	if err == nil {
		t.Fatal("expected EOF after logout")
	}
}
|||
|
|||
func testDiscoverySession(t *testing.T) { |
|||
env := setupSession(t) |
|||
|
|||
// Login as discovery session
|
|||
params := NewParams() |
|||
params.Set("InitiatorName", testInitiatorName) |
|||
params.Set("SessionType", "Discovery") |
|||
|
|||
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
|||
req.SetCmdSN(1) |
|||
WritePDU(env.clientConn, req) |
|||
ReadPDU(env.clientConn) // login resp
|
|||
|
|||
// Send SendTargets=All
|
|||
textParams := NewParams() |
|||
textParams.Set("SendTargets", "All") |
|||
textReq := makeTextReq(textParams) |
|||
textReq.SetCmdSN(2) |
|||
|
|||
if err := WritePDU(env.clientConn, textReq); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
resp, err := ReadPDU(env.clientConn) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if resp.Opcode() != OpTextResp { |
|||
t.Fatalf("expected Text Response, got %s", OpcodeName(resp.Opcode())) |
|||
} |
|||
body := string(resp.DataSegment) |
|||
if len(body) == 0 { |
|||
t.Fatal("empty discovery response") |
|||
} |
|||
} |
|||
|
|||
// testTaskMgmt verifies a Task Management Request (ABORT TASK) gets a
// Task Management Response back.
func testTaskMgmt(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	tm := &PDU{}
	tm.SetOpcode(OpSCSITaskMgmt)
	tm.SetOpSpecific1(FlagF | 0x01) // ABORT TASK
	tm.SetInitiatorTaskTag(0xBBBB)
	tm.SetImmediate(true)

	if err := WritePDU(env.clientConn, tm); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSITaskResp {
		t.Fatalf("expected Task Mgmt Response, got %s", OpcodeName(resp.Opcode()))
	}
}
|||
|
|||
// testRejectSCSIBeforeLogin verifies a SCSI command sent before login
// completes is answered with a Reject PDU.
func testRejectSCSIBeforeLogin(t *testing.T) {
	env := setupSession(t)

	// Send SCSI command without login
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR)
	cmd.SetInitiatorTaskTag(1)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpReject {
		t.Fatalf("expected Reject, got %s", OpcodeName(resp.Opcode()))
	}
}
|||
|
|||
// testConnectionClose verifies the session goroutine observes a client-side
// close promptly and exits without reporting an error.
func testConnectionClose(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Close client side
	env.clientConn.Close()

	select {
	case err := <-env.done:
		if err != nil {
			t.Fatalf("unexpected error on clean close: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("session did not detect close")
	}
}
|||
@ -0,0 +1,216 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"log" |
|||
"net" |
|||
"sync" |
|||
"sync/atomic" |
|||
) |
|||
|
|||
// Sentinel errors returned by TargetServer operations; compare with errors.Is.
var (
	ErrTargetClosed   = errors.New("iscsi: target server closed")
	ErrVolumeNotFound = errors.New("iscsi: volume not found")
)
|||
|
|||
// TargetServer manages the iSCSI target: TCP listener, volume registry,
// and active sessions.
type TargetServer struct {
	mu       sync.RWMutex           // guards listener and volumes
	listener net.Listener           // nil until ListenAndServe/Serve installs it
	config   TargetConfig           // per-session configuration handed to NewSession
	volumes  map[string]BlockDevice // target IQN -> device
	addr     string                 // configured bind address (host:port)

	// Active session tracking for graceful shutdown
	activeMu sync.Mutex          // guards active
	active   map[uint64]*Session // live sessions keyed by server-local id
	nextID   atomic.Uint64       // source of the session ids above

	sessions sync.WaitGroup // one Add per accepted connection; waited on in Close
	closed   chan struct{}  // closed by Close() to signal shutdown
	logger   *log.Logger    // never nil: NewTargetServer defaults it
}
|||
|
|||
// NewTargetServer creates a target server bound to the given address.
|
|||
func NewTargetServer(addr string, config TargetConfig, logger *log.Logger) *TargetServer { |
|||
if logger == nil { |
|||
logger = log.Default() |
|||
} |
|||
return &TargetServer{ |
|||
config: config, |
|||
volumes: make(map[string]BlockDevice), |
|||
active: make(map[uint64]*Session), |
|||
addr: addr, |
|||
closed: make(chan struct{}), |
|||
logger: logger, |
|||
} |
|||
} |
|||
|
|||
// AddVolume registers a block device under the given target IQN.
|
|||
func (ts *TargetServer) AddVolume(iqn string, dev BlockDevice) { |
|||
ts.mu.Lock() |
|||
defer ts.mu.Unlock() |
|||
ts.volumes[iqn] = dev |
|||
ts.logger.Printf("volume added: %s", iqn) |
|||
} |
|||
|
|||
// RemoveVolume unregisters a target IQN.
|
|||
func (ts *TargetServer) RemoveVolume(iqn string) { |
|||
ts.mu.Lock() |
|||
defer ts.mu.Unlock() |
|||
delete(ts.volumes, iqn) |
|||
ts.logger.Printf("volume removed: %s", iqn) |
|||
} |
|||
|
|||
// HasTarget implements TargetResolver.
|
|||
func (ts *TargetServer) HasTarget(name string) bool { |
|||
ts.mu.RLock() |
|||
defer ts.mu.RUnlock() |
|||
_, ok := ts.volumes[name] |
|||
return ok |
|||
} |
|||
|
|||
// ListTargets implements TargetLister.
|
|||
func (ts *TargetServer) ListTargets() []DiscoveryTarget { |
|||
ts.mu.RLock() |
|||
defer ts.mu.RUnlock() |
|||
targets := make([]DiscoveryTarget, 0, len(ts.volumes)) |
|||
for iqn := range ts.volumes { |
|||
targets = append(targets, DiscoveryTarget{ |
|||
Name: iqn, |
|||
Address: ts.ListenAddr(), |
|||
}) |
|||
} |
|||
return targets |
|||
} |
|||
|
|||
// ListenAndServe starts listening for iSCSI connections.
|
|||
// Blocks until Close() is called or an error occurs.
|
|||
func (ts *TargetServer) ListenAndServe() error { |
|||
ln, err := net.Listen("tcp", ts.addr) |
|||
if err != nil { |
|||
return fmt.Errorf("iscsi target: listen: %w", err) |
|||
} |
|||
ts.mu.Lock() |
|||
ts.listener = ln |
|||
ts.mu.Unlock() |
|||
|
|||
ts.logger.Printf("iSCSI target listening on %s", ln.Addr()) |
|||
return ts.acceptLoop(ln) |
|||
} |
|||
|
|||
// Serve accepts connections on an existing listener. It records the listener
// (so Close and ListenAddr can see it) and then blocks in the accept loop
// until Close() is called or Accept fails.
func (ts *TargetServer) Serve(ln net.Listener) error {
	ts.mu.Lock()
	ts.listener = ln
	ts.mu.Unlock()
	return ts.acceptLoop(ln)
}
|||
|
|||
func (ts *TargetServer) acceptLoop(ln net.Listener) error { |
|||
for { |
|||
conn, err := ln.Accept() |
|||
if err != nil { |
|||
select { |
|||
case <-ts.closed: |
|||
return nil |
|||
default: |
|||
return fmt.Errorf("iscsi target: accept: %w", err) |
|||
} |
|||
} |
|||
|
|||
ts.sessions.Add(1) |
|||
go ts.handleConn(conn) |
|||
} |
|||
} |
|||
|
|||
func (ts *TargetServer) handleConn(conn net.Conn) { |
|||
defer ts.sessions.Done() |
|||
defer conn.Close() |
|||
|
|||
ts.logger.Printf("new connection from %s", conn.RemoteAddr()) |
|||
|
|||
sess := NewSession(conn, ts.config, ts, ts, ts.logger) |
|||
|
|||
id := ts.nextID.Add(1) |
|||
ts.activeMu.Lock() |
|||
ts.active[id] = sess |
|||
ts.activeMu.Unlock() |
|||
|
|||
defer func() { |
|||
ts.activeMu.Lock() |
|||
delete(ts.active, id) |
|||
ts.activeMu.Unlock() |
|||
}() |
|||
|
|||
if err := sess.HandleConnection(); err != nil { |
|||
ts.logger.Printf("session error (%s): %v", conn.RemoteAddr(), err) |
|||
} |
|||
|
|||
ts.logger.Printf("connection closed: %s", conn.RemoteAddr()) |
|||
} |
|||
|
|||
// LookupDevice implements DeviceLookup. Returns the BlockDevice for the given IQN,
|
|||
// or a nullDevice if not found.
|
|||
func (ts *TargetServer) LookupDevice(iqn string) BlockDevice { |
|||
ts.mu.RLock() |
|||
defer ts.mu.RUnlock() |
|||
|
|||
if dev, ok := ts.volumes[iqn]; ok { |
|||
return dev |
|||
} |
|||
return &nullDevice{} |
|||
} |
|||
|
|||
// Close gracefully shuts down the target server.
|
|||
func (ts *TargetServer) Close() error { |
|||
select { |
|||
case <-ts.closed: |
|||
return nil |
|||
default: |
|||
} |
|||
close(ts.closed) |
|||
|
|||
ts.mu.RLock() |
|||
ln := ts.listener |
|||
ts.mu.RUnlock() |
|||
|
|||
if ln != nil { |
|||
ln.Close() |
|||
} |
|||
|
|||
// Close all active sessions to unblock ReadPDU
|
|||
ts.activeMu.Lock() |
|||
for _, sess := range ts.active { |
|||
sess.Close() |
|||
} |
|||
ts.activeMu.Unlock() |
|||
|
|||
ts.sessions.Wait() |
|||
return nil |
|||
} |
|||
|
|||
// ListenAddr returns the actual listen address (useful when port=0).
|
|||
func (ts *TargetServer) ListenAddr() string { |
|||
ts.mu.RLock() |
|||
defer ts.mu.RUnlock() |
|||
if ts.listener != nil { |
|||
return ts.listener.Addr().String() |
|||
} |
|||
return ts.addr |
|||
} |
|||
|
|||
// nullDevice is a stub BlockDevice used when no volumes are registered.
// Reads yield zero-filled buffers, writes/trims/syncs are discarded, and
// IsHealthy reports false so callers can detect the missing backing volume.
type nullDevice struct{}

// ReadAt returns a zero-filled buffer of the requested length.
func (d *nullDevice) ReadAt(lba uint64, length uint32) ([]byte, error) {
	buf := make([]byte, length) // fresh slices are already zeroed
	return buf, nil
}

// WriteAt discards the data.
func (d *nullDevice) WriteAt(lba uint64, data []byte) error { return nil }

// Trim is a no-op.
func (d *nullDevice) Trim(lba uint64, length uint32) error { return nil }

// SyncCache is a no-op.
func (d *nullDevice) SyncCache() error { return nil }

// BlockSize reports a conventional 4 KiB block.
func (d *nullDevice) BlockSize() uint32 { return 4096 }

// VolumeSize reports zero capacity.
func (d *nullDevice) VolumeSize() uint64 { return 0 }

// IsHealthy always reports false for the stub device.
func (d *nullDevice) IsHealthy() bool { return false }
|||
@ -0,0 +1,259 @@ |
|||
package iscsi |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"strings" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
// TestTarget is the table-driven driver for all TargetServer tests; each
// entry runs as an isolated subtest.
func TestTarget(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"listen_and_connect", testListenAndConnect},
		{"discovery_via_target", testDiscoveryViaTarget},
		{"login_read_write", testTargetLoginReadWrite},
		{"graceful_shutdown", testGracefulShutdown},
		{"add_remove_volume", testAddRemoveVolume},
		{"multiple_connections", testMultipleConnections},
		{"connect_no_volumes", testConnectNoVolumes},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// setupTarget starts a TargetServer on an ephemeral port with one mock
// volume registered under testTargetName. It returns the server and its
// dial address, and registers cleanup to close the server after the test.
func setupTarget(t *testing.T) (*TargetServer, string) {
	t.Helper()
	config := DefaultTargetConfig()
	config.TargetName = testTargetName
	logger := log.New(io.Discard, "", 0)
	ts := NewTargetServer("127.0.0.1:0", config, logger)

	dev := newMockDevice(256 * 4096) // 256 blocks = 1MB
	ts.AddVolume(testTargetName, dev)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}

	go ts.Serve(ln)

	addr := ln.Addr().String()
	t.Cleanup(func() {
		ts.Close()
	})

	return ts, addr
}
|||
|
|||
// dialTarget connects to addr with a 2s timeout, fails the test on error,
// and registers cleanup to close the connection.
func dialTarget(t *testing.T, addr string) net.Conn {
	t.Helper()
	c, err := net.DialTimeout("tcp", addr, 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		c.Close()
	})
	return c
}
|||
|
|||
// loginToTarget performs a Normal-session login against testTargetName over
// conn and fails the test unless the login response reports success.
func loginToTarget(t *testing.T, conn net.Conn) {
	t.Helper()
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)

	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
}
|||
|
|||
// testListenAndConnect is the smoke test: a client can dial the target and
// complete a Normal-session login.
func testListenAndConnect(t *testing.T) {
	_, addr := setupTarget(t)

	conn := dialTarget(t, addr)
	loginToTarget(t, conn)
}
|||
|
|||
func testDiscoveryViaTarget(t *testing.T) { |
|||
_, addr := setupTarget(t) |
|||
|
|||
conn := dialTarget(t, addr) |
|||
|
|||
// Login as discovery
|
|||
params := NewParams() |
|||
params.Set("InitiatorName", testInitiatorName) |
|||
params.Set("SessionType", "Discovery") |
|||
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
|||
req.SetCmdSN(1) |
|||
WritePDU(conn, req) |
|||
ReadPDU(conn) |
|||
|
|||
// SendTargets=All
|
|||
textParams := NewParams() |
|||
textParams.Set("SendTargets", "All") |
|||
textReq := makeTextReq(textParams) |
|||
textReq.SetCmdSN(2) |
|||
WritePDU(conn, textReq) |
|||
|
|||
resp, err := ReadPDU(conn) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
body := string(resp.DataSegment) |
|||
if !strings.Contains(body, testTargetName) { |
|||
t.Fatalf("discovery response missing target: %q", body) |
|||
} |
|||
} |
|||
|
|||
func testTargetLoginReadWrite(t *testing.T) { |
|||
_, addr := setupTarget(t) |
|||
conn := dialTarget(t, addr) |
|||
loginToTarget(t, conn) |
|||
|
|||
// Write 1 block at LBA 0
|
|||
cmd := &PDU{} |
|||
cmd.SetOpcode(OpSCSICmd) |
|||
cmd.SetOpSpecific1(FlagF | FlagW) |
|||
cmd.SetInitiatorTaskTag(1) |
|||
cmd.SetExpectedDataTransferLength(4096) |
|||
cmd.SetCmdSN(2) |
|||
var cdb [16]byte |
|||
cdb[0] = ScsiWrite10 |
|||
binary.BigEndian.PutUint32(cdb[2:6], 0) |
|||
binary.BigEndian.PutUint16(cdb[7:9], 1) |
|||
cmd.SetCDB(cdb) |
|||
data := make([]byte, 4096) |
|||
for i := range data { |
|||
data[i] = 0xAB |
|||
} |
|||
cmd.DataSegment = data |
|||
|
|||
WritePDU(conn, cmd) |
|||
resp, _ := ReadPDU(conn) |
|||
if resp.SCSIStatus() != SCSIStatusGood { |
|||
t.Fatalf("write failed: %d", resp.SCSIStatus()) |
|||
} |
|||
|
|||
// Read it back
|
|||
cmd2 := &PDU{} |
|||
cmd2.SetOpcode(OpSCSICmd) |
|||
cmd2.SetOpSpecific1(FlagF | FlagR) |
|||
cmd2.SetInitiatorTaskTag(2) |
|||
cmd2.SetExpectedDataTransferLength(4096) |
|||
cmd2.SetCmdSN(3) |
|||
var cdb2 [16]byte |
|||
cdb2[0] = ScsiRead10 |
|||
binary.BigEndian.PutUint32(cdb2[2:6], 0) |
|||
binary.BigEndian.PutUint16(cdb2[7:9], 1) |
|||
cmd2.SetCDB(cdb2) |
|||
|
|||
WritePDU(conn, cmd2) |
|||
resp2, _ := ReadPDU(conn) |
|||
if resp2.Opcode() != OpSCSIDataIn { |
|||
t.Fatalf("expected Data-In, got %s", OpcodeName(resp2.Opcode())) |
|||
} |
|||
if resp2.DataSegment[0] != 0xAB { |
|||
t.Fatal("data mismatch") |
|||
} |
|||
} |
|||
|
|||
// testGracefulShutdown verifies that Close() drops a logged-in connection:
// the client's next read must fail rather than hang.
func testGracefulShutdown(t *testing.T) {
	ts, addr := setupTarget(t)

	conn := dialTarget(t, addr)
	loginToTarget(t, conn)

	// Close the target — should shut down cleanly
	ts.Close()

	// Connection should be dropped
	_, err := ReadPDU(conn)
	if err == nil {
		t.Fatal("expected error after shutdown")
	}
}
|||
|
|||
// testAddRemoveVolume exercises the volume registry directly (no network):
// add two volumes, remove one, and check HasTarget/ListTargets agree.
func testAddRemoveVolume(t *testing.T) {
	config := DefaultTargetConfig()
	logger := log.New(io.Discard, "", 0)
	ts := NewTargetServer("127.0.0.1:0", config, logger)

	ts.AddVolume("iqn.2024.com.test:v1", newMockDevice(1024*4096))
	ts.AddVolume("iqn.2024.com.test:v2", newMockDevice(1024*4096))

	if !ts.HasTarget("iqn.2024.com.test:v1") {
		t.Fatal("v1 should exist")
	}

	ts.RemoveVolume("iqn.2024.com.test:v1")
	if ts.HasTarget("iqn.2024.com.test:v1") {
		t.Fatal("v1 should be removed")
	}
	if !ts.HasTarget("iqn.2024.com.test:v2") {
		t.Fatal("v2 should still exist")
	}

	targets := ts.ListTargets()
	if len(targets) != 1 {
		t.Fatalf("expected 1 target, got %d", len(targets))
	}
}
|||
|
|||
// testMultipleConnections verifies the target accepts and logs in several
// concurrent client connections (cleanup closes them via t.Cleanup).
func testMultipleConnections(t *testing.T) {
	_, addr := setupTarget(t)

	// Connect multiple clients
	for i := 0; i < 5; i++ {
		conn := dialTarget(t, addr)
		loginToTarget(t, conn)
	}
}
|||
|
|||
func testConnectNoVolumes(t *testing.T) { |
|||
config := DefaultTargetConfig() |
|||
logger := log.New(io.Discard, "", 0) |
|||
ts := NewTargetServer("127.0.0.1:0", config, logger) |
|||
// No volumes added
|
|||
|
|||
ln, err := net.Listen("tcp", "127.0.0.1:0") |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
go ts.Serve(ln) |
|||
t.Cleanup(func() { ts.Close() }) |
|||
|
|||
conn := dialTarget(t, ln.Addr().String()) |
|||
|
|||
// Login should work (discovery is fine without volumes)
|
|||
params := NewParams() |
|||
params.Set("InitiatorName", testInitiatorName) |
|||
params.Set("SessionType", "Discovery") |
|||
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
|||
req.SetCmdSN(1) |
|||
WritePDU(conn, req) |
|||
resp, _ := ReadPDU(conn) |
|||
if resp.LoginStatusClass() != LoginStatusSuccess { |
|||
t.Fatal("discovery login should succeed without volumes") |
|||
} |
|||
} |
|||
@ -0,0 +1,38 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
) |
|||
|
|||
// Sentinel errors returned by the LBA validators; compare with errors.Is.
var (
	ErrLBAOutOfBounds = errors.New("blockvol: LBA out of bounds")
	ErrWritePastEnd   = errors.New("blockvol: write extends past volume end")
	ErrAlignment      = errors.New("blockvol: data length not aligned to block size")
)

// ValidateLBA checks that lba is within the volume's logical address space.
// The volume holds volumeSize/blockSize addressable blocks; valid LBAs are
// [0, maxLBA). Returns an error wrapping ErrLBAOutOfBounds otherwise.
func ValidateLBA(lba uint64, volumeSize uint64, blockSize uint32) error {
	maxLBA := volumeSize / uint64(blockSize)
	if lba >= maxLBA {
		if maxLBA == 0 {
			// volumeSize < blockSize: the volume holds no full blocks.
			// Guard separately so the maxLBA-1 below cannot underflow to
			// 2^64-1 and produce a nonsensical "max" in the message.
			return fmt.Errorf("%w: lba=%d, volume holds no blocks", ErrLBAOutOfBounds, lba)
		}
		return fmt.Errorf("%w: lba=%d, max=%d", ErrLBAOutOfBounds, lba, maxLBA-1)
	}
	return nil
}

// ValidateWrite checks that a write at lba with dataLen bytes fits within
// the volume and that dataLen is aligned to blockSize. Returns an error
// wrapping ErrLBAOutOfBounds, ErrAlignment, or ErrWritePastEnd.
func ValidateWrite(lba uint64, dataLen uint32, volumeSize uint64, blockSize uint32) error {
	if err := ValidateLBA(lba, volumeSize, blockSize); err != nil {
		return err
	}
	if dataLen%blockSize != 0 {
		return fmt.Errorf("%w: dataLen=%d, blockSize=%d", ErrAlignment, dataLen, blockSize)
	}
	blocksNeeded := uint64(dataLen / blockSize)
	maxLBA := volumeSize / uint64(blockSize)
	// Compare as "blocks remaining after lba" rather than lba+blocksNeeded,
	// which could overflow uint64 for LBAs near the top of the address space.
	// ValidateLBA above guarantees lba < maxLBA, so maxLBA-lba cannot wrap.
	if blocksNeeded > maxLBA-lba {
		return fmt.Errorf("%w: lba=%d, blocks=%d, max=%d", ErrWritePastEnd, lba, blocksNeeded, maxLBA)
	}
	return nil
}
|||
@ -0,0 +1,95 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"errors" |
|||
"testing" |
|||
) |
|||
|
|||
// TestLBAValidation is the table-driven driver for the ValidateLBA and
// ValidateWrite subtests below.
func TestLBAValidation(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "lba_within_bounds", run: testLBAWithinBounds},
		{name: "lba_last_block", run: testLBALastBlock},
		{name: "lba_out_of_bounds", run: testLBAOutOfBounds},
		{name: "lba_write_spans_end", run: testLBAWriteSpansEnd},
		{name: "lba_alignment", run: testLBAAlignment},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// Fixture geometry shared by all LBA validation subtests.
const (
	testVolSize   = 100 * 1024 * 1024 * 1024 // 100GB
	testBlockSize = 4096
)
|||
|
|||
// testLBAWithinBounds checks that in-range LBAs (first block and an interior
// block) validate cleanly.
func testLBAWithinBounds(t *testing.T) {
	err := ValidateLBA(0, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("LBA=0 should be valid: %v", err)
	}

	err = ValidateLBA(1000, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("LBA=1000 should be valid: %v", err)
	}
}
|||
|
|||
// testLBALastBlock checks the inclusive upper boundary: the very last
// addressable block must validate.
func testLBALastBlock(t *testing.T) {
	maxLBA := uint64(testVolSize/testBlockSize) - 1
	err := ValidateLBA(maxLBA, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("last LBA=%d should be valid: %v", maxLBA, err)
	}
}
|||
|
|||
// testLBAOutOfBounds checks the exclusive upper boundary (one past the end)
// and a far-out-of-range LBA both yield ErrLBAOutOfBounds.
func testLBAOutOfBounds(t *testing.T) {
	maxLBA := uint64(testVolSize / testBlockSize)
	err := ValidateLBA(maxLBA, testVolSize, testBlockSize)
	if !errors.Is(err, ErrLBAOutOfBounds) {
		t.Errorf("LBA=%d (one past end): expected ErrLBAOutOfBounds, got %v", maxLBA, err)
	}

	err = ValidateLBA(maxLBA+1000, testVolSize, testBlockSize)
	if !errors.Is(err, ErrLBAOutOfBounds) {
		t.Errorf("LBA far past end: expected ErrLBAOutOfBounds, got %v", err)
	}
}
|||
|
|||
// testLBAWriteSpansEnd checks that a multi-block write straddling the end
// of the volume fails with ErrWritePastEnd, while a single-block write at
// the final LBA succeeds.
func testLBAWriteSpansEnd(t *testing.T) {
	maxLBA := uint64(testVolSize / testBlockSize)
	// Write 2 blocks starting at last LBA -- spans past end.
	lastLBA := maxLBA - 1
	err := ValidateWrite(lastLBA, 2*testBlockSize, testVolSize, testBlockSize)
	if !errors.Is(err, ErrWritePastEnd) {
		t.Errorf("write spanning end: expected ErrWritePastEnd, got %v", err)
	}

	// Write 1 block at last LBA -- should succeed.
	err = ValidateWrite(lastLBA, testBlockSize, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("write at last LBA should succeed: %v", err)
	}
}
|||
|
|||
// testLBAAlignment checks that only block-aligned data lengths validate:
// exactly one block passes; 4000 bytes and 1 byte fail with ErrAlignment.
func testLBAAlignment(t *testing.T) {
	err := ValidateWrite(0, 4096, testVolSize, testBlockSize)
	if err != nil {
		t.Errorf("aligned write should succeed: %v", err)
	}

	err = ValidateWrite(0, 4000, testVolSize, testBlockSize)
	if !errors.Is(err, ErrAlignment) {
		t.Errorf("unaligned write: expected ErrAlignment, got %v", err)
	}

	err = ValidateWrite(0, 1, testVolSize, testBlockSize)
	if !errors.Is(err, ErrAlignment) {
		t.Errorf("1-byte write: expected ErrAlignment, got %v", err)
	}
}
|||
@ -0,0 +1,157 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"fmt" |
|||
"os" |
|||
) |
|||
|
|||
// RecoveryResult contains the outcome of WAL recovery. It summarizes what
// RecoverWAL replayed so the caller can log it and resume LSN allocation
// above HighestLSN.
type RecoveryResult struct {
	EntriesReplayed int    // number of entries replayed into dirty map
	HighestLSN      uint64 // highest LSN seen during recovery
	TornEntries     int    // entries discarded due to CRC failure
}
|||
|
|||
// RecoverWAL scans the WAL region from tail to head, replaying valid entries
|
|||
// into the dirty map. Entries with LSN <= checkpointLSN are skipped (already
|
|||
// in extent). Scanning stops at the first CRC failure (torn write).
|
|||
//
|
|||
// The WAL is a circular buffer. If head >= tail, scan [tail, head).
|
|||
// If head < tail (wrapped), scan [tail, walSize) then [0, head).
|
|||
func RecoverWAL(fd *os.File, sb *Superblock, dirtyMap *DirtyMap) (RecoveryResult, error) { |
|||
result := RecoveryResult{} |
|||
|
|||
logicalHead := sb.WALHead |
|||
logicalTail := sb.WALTail |
|||
walOffset := sb.WALOffset |
|||
walSize := sb.WALSize |
|||
checkpointLSN := sb.WALCheckpointLSN |
|||
|
|||
if logicalHead == logicalTail { |
|||
// WAL is empty (or fully flushed).
|
|||
return result, nil |
|||
} |
|||
|
|||
// Convert logical positions to physical.
|
|||
physHead := logicalHead % walSize |
|||
physTail := logicalTail % walSize |
|||
|
|||
// Build the list of byte ranges to scan.
|
|||
type scanRange struct { |
|||
start, end uint64 // physical positions within WAL
|
|||
} |
|||
|
|||
var ranges []scanRange |
|||
if physHead > physTail { |
|||
// No wrap: scan [tail, head).
|
|||
ranges = append(ranges, scanRange{physTail, physHead}) |
|||
} else if physHead == physTail { |
|||
// Head and tail at same physical position but different logical positions
|
|||
// means the WAL is completely full. Scan the entire region.
|
|||
ranges = append(ranges, scanRange{physTail, walSize}) |
|||
if physHead > 0 { |
|||
ranges = append(ranges, scanRange{0, physHead}) |
|||
} |
|||
} else { |
|||
// Wrapped: scan [tail, walSize) then [0, head).
|
|||
ranges = append(ranges, scanRange{physTail, walSize}) |
|||
if physHead > 0 { |
|||
ranges = append(ranges, scanRange{0, physHead}) |
|||
} |
|||
} |
|||
|
|||
for _, r := range ranges { |
|||
pos := r.start |
|||
for pos < r.end { |
|||
remaining := r.end - pos |
|||
|
|||
// Need at least a header to proceed.
|
|||
if remaining < uint64(walEntryHeaderSize) { |
|||
break |
|||
} |
|||
|
|||
// Read header.
|
|||
headerBuf := make([]byte, walEntryHeaderSize) |
|||
absOff := int64(walOffset + pos) |
|||
if _, err := fd.ReadAt(headerBuf, absOff); err != nil { |
|||
return result, fmt.Errorf("recovery: read header at WAL+%d: %w", pos, err) |
|||
} |
|||
|
|||
// Parse entry type and length field.
|
|||
entryType := headerBuf[16] |
|||
lengthField := parseLength(headerBuf) |
|||
|
|||
// For padding entries, skip forward.
|
|||
if entryType == EntryTypePadding { |
|||
entrySize := uint64(walEntryHeaderSize) + uint64(lengthField) |
|||
pos += entrySize |
|||
continue |
|||
} |
|||
|
|||
// Calculate on-disk entry size. WRITE and PADDING carry data payload;
|
|||
// TRIM and BARRIER do not (Length is metadata, not data size).
|
|||
var payloadLen uint64 |
|||
if entryType == EntryTypeWrite { |
|||
payloadLen = uint64(lengthField) |
|||
} |
|||
entrySize := uint64(walEntryHeaderSize) + payloadLen |
|||
if entrySize > remaining { |
|||
// Torn write: entry extends past available data.
|
|||
result.TornEntries++ |
|||
break |
|||
} |
|||
|
|||
// Read full entry.
|
|||
fullBuf := make([]byte, entrySize) |
|||
if _, err := fd.ReadAt(fullBuf, absOff); err != nil { |
|||
return result, fmt.Errorf("recovery: read entry at WAL+%d: %w", pos, err) |
|||
} |
|||
|
|||
// Decode and validate CRC.
|
|||
entry, err := DecodeWALEntry(fullBuf) |
|||
if err != nil { |
|||
// CRC failure or corrupt entry — stop here (torn write).
|
|||
result.TornEntries++ |
|||
break |
|||
} |
|||
|
|||
// Skip entries already flushed to extent.
|
|||
if entry.LSN <= checkpointLSN { |
|||
pos += entrySize |
|||
continue |
|||
} |
|||
|
|||
// Replay entry.
|
|||
switch entry.Type { |
|||
case EntryTypeWrite: |
|||
blocks := entry.Length / sb.BlockSize |
|||
for i := uint32(0); i < blocks; i++ { |
|||
dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) |
|||
} |
|||
result.EntriesReplayed++ |
|||
|
|||
case EntryTypeTrim: |
|||
// TRIM carries Length (bytes) covering multiple blocks.
|
|||
blocks := entry.Length / sb.BlockSize |
|||
if blocks == 0 { |
|||
blocks = 1 // legacy single-block trim
|
|||
} |
|||
for i := uint32(0); i < blocks; i++ { |
|||
dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) |
|||
} |
|||
result.EntriesReplayed++ |
|||
|
|||
case EntryTypeBarrier: |
|||
// Barriers don't modify data, just skip.
|
|||
} |
|||
|
|||
if entry.LSN > result.HighestLSN { |
|||
result.HighestLSN = entry.LSN |
|||
} |
|||
|
|||
pos += entrySize |
|||
} |
|||
} |
|||
|
|||
return result, nil |
|||
} |
|||
@ -0,0 +1,416 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"path/filepath" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
// TestRecovery is the table-driven driver for the WAL crash-recovery
// subtests below.
func TestRecovery(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{name: "recover_empty_wal", run: testRecoverEmptyWAL},
		{name: "recover_one_entry", run: testRecoverOneEntry},
		{name: "recover_many_entries", run: testRecoverManyEntries},
		{name: "recover_torn_write", run: testRecoverTornWrite},
		{name: "recover_after_checkpoint", run: testRecoverAfterCheckpoint},
		{name: "recover_idempotent", run: testRecoverIdempotent},
		{name: "recover_wal_full", run: testRecoverWALFull},
		{name: "recover_barrier_only", run: testRecoverBarrierOnly},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
|||
|
|||
// simulateCrash closes the volume without syncing and returns the path.
// It stops group commit and closes the fd directly, bypassing the orderly
// shutdown path, so the on-disk state is what an interrupted process
// would have left behind.
func simulateCrash(v *BlockVol) string {
	path := v.Path()
	v.groupCommit.Stop()
	v.fd.Close()
	return path
}
|||
|
|||
// testRecoverEmptyWAL crashes a volume with no WAL entries and verifies
// reopening succeeds and reads return zeros.
func testRecoverEmptyWAL(t *testing.T) {
	v := createTestVol(t)
	// Sync to make superblock durable.
	v.fd.Sync()
	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Empty volume: read should return zeros.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("expected zeros from empty volume")
	}
}
|||
|
|||
// testRecoverOneEntry writes one block, makes the WAL entry and superblock
// durable, crashes, and verifies recovery replays the block.
func testRecoverOneEntry(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Sync WAL to make the entry durable.
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Update superblock with WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not recovered")
	}
}
|||
|
|||
// testRecoverManyEntries writes 400 blocks (with a WAL sized to hold them),
// crashes, and verifies every block's data is replayed on reopen.
func testRecoverManyEntries(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "many.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 4 * 1024 * 1024, // 4MB
		BlockSize:  4096,
		WALSize:    2 * 1024 * 1024, // 2MB WAL -- enough for ~500 entries
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	const numWrites = 400
	for i := uint64(0); i < numWrites; i++ {
		if err := v.WriteLBA(i, makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL head/tail in the superblock so recovery knows the range.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path = simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	for i := uint64(0); i < numWrites; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after recovery: %v", i, err)
		}
		expected := makeBlock(byte(i%26 + 'A'))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after recovery", i)
		}
	}
}
|||
|
|||
// testRecoverTornWrite corrupts the tail bytes of the second WAL entry to
// simulate a torn write, then verifies recovery keeps the first entry and
// drops the torn one (reads of its block fall back to extent zeros).
func testRecoverTornWrite(t *testing.T) {
	v := createTestVol(t)

	// Write 2 entries.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(0): %v", err)
	}
	if err := v.WriteLBA(1, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(1): %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Save superblock with correct WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	// Corrupt the last 2 bytes of the second entry to simulate torn write.
	entrySize := uint64(walEntryHeaderSize + 4096)
	secondEntryEnd := v.super.WALOffset + entrySize*2
	corruptOff := int64(secondEntryEnd - 2)
	v.fd.WriteAt([]byte{0xFF, 0xFF}, corruptOff)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// First entry should be recovered.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, makeBlock('A')) {
		t.Error("block 0 should be recovered")
	}

	// Second entry was torn -- should NOT be in dirty map.
	// Read returns zeros (from extent).
	got, err = v2.ReadLBA(1, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(1) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("block 1 (torn) should return zeros")
	}
}
|||
|
|||
// testRecoverAfterCheckpoint flushes a first batch of blocks into the extent
// (advancing the checkpoint), writes a second batch that stays only in the
// WAL, crashes, and verifies both batches read back correctly — the first
// from extent, the second via WAL replay.
func testRecoverAfterCheckpoint(t *testing.T) {
	v := createTestVol(t)

	// Write blocks 0-9.
	for i := uint64(0); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Flush first 5 blocks (flusher moves them to extent, advances checkpoint).
	f := NewFlusher(FlusherConfig{
		FD:       v.fd,
		Super:    &v.super,
		WAL:      v.wal,
		DirtyMap: v.dirtyMap,
		Interval: 1 * time.Hour, // long interval: only explicit FlushOnce runs
	})

	// Flush all (flusher takes all dirty entries).
	// To simulate partial flush, we'll manually set checkpoint to midpoint.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// Now write 5 more blocks (these will need replay after crash).
	for i := uint64(10); i < 15; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Update superblock.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Blocks 0-9 should be in extent (flushed). Blocks 10-14 replayed from WAL.
	for i := uint64(0); i < 15; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after checkpoint recovery", i)
		}
	}
}
|||
|
|||
// testRecoverIdempotent crashes and recovers the same volume twice and
// verifies the data survives both recoveries unchanged — replay must not
// corrupt state when run repeatedly.
func testRecoverIdempotent(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('X')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	// First recovery.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol 1: %v", err)
	}

	// Verify data.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 1: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("first recovery: data mismatch")
	}

	// Close and recover again (should be idempotent).
	path2 := simulateCrash(v2)

	v3, err := OpenBlockVol(path2)
	if err != nil {
		t.Fatalf("OpenBlockVol 2: %v", err)
	}
	defer v3.Close()

	got, err = v3.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 2: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("second recovery: data mismatch")
	}
}
|||
|
|||
// testRecoverWALFull verifies recovery of a WAL sized to hold exactly the
// number of entries written, exercising replay when the log region has no
// slack space left.
func testRecoverWALFull(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "full.blockvol")

	// Size the WAL for exactly 5 full-block entries (header + 4KB payload).
	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 5 // room for exactly 5 entries

	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	// Write exactly 5 entries to fill WAL.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL positions in the superblock, then fsync, so recovery
	// sees a consistent picture of the (completely full) log.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path = simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// All 5 blocks must be replayed from the full WAL.
	for i := uint64(0); i < 5; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after full WAL recovery", i)
		}
	}
}
|||
|
|||
// testRecoverBarrierOnly verifies that replaying a WAL containing only a
// BARRIER entry makes no data changes: the volume must still read back as
// all zeros after recovery.
func testRecoverBarrierOnly(t *testing.T) {
	v := createTestVol(t)

	// Write a barrier entry directly to the WAL, bypassing WriteLBA.
	lsn := v.nextLSN.Add(1) - 1
	entry := &WALEntry{
		LSN:  lsn,
		Type: EntryTypeBarrier,
		LBA:  0,
	}
	if _, err := v.wal.Append(entry); err != nil {
		t.Fatalf("Append barrier: %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL positions so recovery replays the barrier entry.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// No data changes from barrier.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("barrier-only WAL should leave data as zeros")
	}
}
|||
@ -0,0 +1,249 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
|
|||
"github.com/google/uuid" |
|||
) |
|||
|
|||
const (
	// SuperblockSize is the fixed on-disk size of the superblock; encoded
	// fields occupy the front, the remainder is zero padding.
	SuperblockSize = 4096
	// MagicSWBK is the 4-byte magic at offset 0 of every blockvol file.
	MagicSWBK = "SWBK"
	// CurrentVersion is the only on-disk format version this code accepts.
	CurrentVersion = 1
)

// Sentinel errors returned by superblock parsing/validation; match with
// errors.Is.
var (
	ErrNotBlockVol        = errors.New("blockvol: not a blockvol file (bad magic)")
	ErrUnsupportedVersion = errors.New("blockvol: unsupported version")
	ErrInvalidVolumeSize  = errors.New("blockvol: volume size must be > 0")
)

// Superblock is the 4KB header at offset 0 of a blockvol file.
// It identifies the file format, stores volume geometry, and tracks WAL state.
type Superblock struct {
	Magic            [4]byte
	Version          uint16
	Flags            uint16
	UUID             [16]byte
	VolumeSize       uint64 // logical size in bytes
	ExtentSize       uint32 // default 64KB
	BlockSize        uint32 // default 4KB (min I/O unit)
	WALOffset        uint64 // byte offset where WAL region starts
	WALSize          uint64 // WAL region size in bytes
	WALHead          uint64 // logical WAL write position (monotonically increasing)
	WALTail          uint64 // logical WAL flush position (monotonically increasing)
	WALCheckpointLSN uint64 // last LSN flushed to extent region
	Replication      [4]byte
	CreatedAt        uint64 // unix timestamp
	SnapshotCount    uint32
}

// superblockOnDisk is the fixed-size on-disk layout (binary.Write/Read target).
// Must be <= SuperblockSize bytes. Remaining space is zero-padded.
// Field order here defines the little-endian on-disk encoding produced by
// WriteTo and consumed by ReadSuperblock.
type superblockOnDisk struct {
	Magic            [4]byte
	Version          uint16
	Flags            uint16
	UUID             [16]byte
	VolumeSize       uint64
	ExtentSize       uint32
	BlockSize        uint32
	WALOffset        uint64
	WALSize          uint64
	WALHead          uint64
	WALTail          uint64
	WALCheckpointLSN uint64
	Replication      [4]byte
	CreatedAt        uint64
	SnapshotCount    uint32
}
|||
|
|||
// NewSuperblock creates a superblock with defaults and a fresh UUID.
|
|||
func NewSuperblock(volumeSize uint64, opts CreateOptions) (Superblock, error) { |
|||
if volumeSize == 0 { |
|||
return Superblock{}, ErrInvalidVolumeSize |
|||
} |
|||
|
|||
extentSize := opts.ExtentSize |
|||
if extentSize == 0 { |
|||
extentSize = 64 * 1024 // 64KB
|
|||
} |
|||
blockSize := opts.BlockSize |
|||
if blockSize == 0 { |
|||
blockSize = 4096 // 4KB
|
|||
} |
|||
walSize := opts.WALSize |
|||
if walSize == 0 { |
|||
walSize = 64 * 1024 * 1024 // 64MB
|
|||
} |
|||
|
|||
var repl [4]byte |
|||
if opts.Replication != "" { |
|||
copy(repl[:], opts.Replication) |
|||
} else { |
|||
copy(repl[:], "000") |
|||
} |
|||
|
|||
id := uuid.New() |
|||
|
|||
sb := Superblock{ |
|||
Version: CurrentVersion, |
|||
VolumeSize: volumeSize, |
|||
ExtentSize: extentSize, |
|||
BlockSize: blockSize, |
|||
WALOffset: SuperblockSize, |
|||
WALSize: walSize, |
|||
} |
|||
copy(sb.Magic[:], MagicSWBK) |
|||
sb.UUID = id |
|||
sb.Replication = repl |
|||
|
|||
return sb, nil |
|||
} |
|||
|
|||
// WriteTo serializes the superblock to w as a 4096-byte block.
|
|||
func (sb *Superblock) WriteTo(w io.Writer) (int64, error) { |
|||
buf := make([]byte, SuperblockSize) |
|||
|
|||
d := superblockOnDisk{ |
|||
Magic: sb.Magic, |
|||
Version: sb.Version, |
|||
Flags: sb.Flags, |
|||
UUID: sb.UUID, |
|||
VolumeSize: sb.VolumeSize, |
|||
ExtentSize: sb.ExtentSize, |
|||
BlockSize: sb.BlockSize, |
|||
WALOffset: sb.WALOffset, |
|||
WALSize: sb.WALSize, |
|||
WALHead: sb.WALHead, |
|||
WALTail: sb.WALTail, |
|||
WALCheckpointLSN: sb.WALCheckpointLSN, |
|||
Replication: sb.Replication, |
|||
CreatedAt: sb.CreatedAt, |
|||
SnapshotCount: sb.SnapshotCount, |
|||
} |
|||
|
|||
// Encode into beginning of buf; rest stays zero (padding).
|
|||
endian := binary.LittleEndian |
|||
off := 0 |
|||
off += copy(buf[off:], d.Magic[:]) |
|||
endian.PutUint16(buf[off:], d.Version) |
|||
off += 2 |
|||
endian.PutUint16(buf[off:], d.Flags) |
|||
off += 2 |
|||
off += copy(buf[off:], d.UUID[:]) |
|||
endian.PutUint64(buf[off:], d.VolumeSize) |
|||
off += 8 |
|||
endian.PutUint32(buf[off:], d.ExtentSize) |
|||
off += 4 |
|||
endian.PutUint32(buf[off:], d.BlockSize) |
|||
off += 4 |
|||
endian.PutUint64(buf[off:], d.WALOffset) |
|||
off += 8 |
|||
endian.PutUint64(buf[off:], d.WALSize) |
|||
off += 8 |
|||
endian.PutUint64(buf[off:], d.WALHead) |
|||
off += 8 |
|||
endian.PutUint64(buf[off:], d.WALTail) |
|||
off += 8 |
|||
endian.PutUint64(buf[off:], d.WALCheckpointLSN) |
|||
off += 8 |
|||
off += copy(buf[off:], d.Replication[:]) |
|||
endian.PutUint64(buf[off:], d.CreatedAt) |
|||
off += 8 |
|||
endian.PutUint32(buf[off:], d.SnapshotCount) |
|||
|
|||
n, err := w.Write(buf) |
|||
return int64(n), err |
|||
} |
|||
|
|||
// ReadSuperblock reads and validates a superblock from r.
//
// It reads exactly SuperblockSize bytes, checks the magic and version
// eagerly (so garbage files fail fast), rejects a zero VolumeSize, and
// decodes the remaining fields in on-disk order. Returns ErrNotBlockVol,
// ErrUnsupportedVersion, or ErrInvalidVolumeSize (wrapped where noted) on
// malformed input.
func ReadSuperblock(r io.Reader) (Superblock, error) {
	buf := make([]byte, SuperblockSize)
	if _, err := io.ReadFull(r, buf); err != nil {
		return Superblock{}, fmt.Errorf("blockvol: read superblock: %w", err)
	}

	endian := binary.LittleEndian
	var sb Superblock
	off := 0
	copy(sb.Magic[:], buf[off:off+4])
	off += 4

	if string(sb.Magic[:]) != MagicSWBK {
		return Superblock{}, ErrNotBlockVol
	}

	sb.Version = endian.Uint16(buf[off:])
	off += 2
	if sb.Version != CurrentVersion {
		return Superblock{}, fmt.Errorf("%w: got %d, want %d", ErrUnsupportedVersion, sb.Version, CurrentVersion)
	}

	sb.Flags = endian.Uint16(buf[off:])
	off += 2
	copy(sb.UUID[:], buf[off:off+16])
	off += 16
	sb.VolumeSize = endian.Uint64(buf[off:])
	off += 8

	if sb.VolumeSize == 0 {
		return Superblock{}, ErrInvalidVolumeSize
	}

	// Decode order below mirrors superblockOnDisk / WriteTo exactly.
	sb.ExtentSize = endian.Uint32(buf[off:])
	off += 4
	sb.BlockSize = endian.Uint32(buf[off:])
	off += 4
	sb.WALOffset = endian.Uint64(buf[off:])
	off += 8
	sb.WALSize = endian.Uint64(buf[off:])
	off += 8
	sb.WALHead = endian.Uint64(buf[off:])
	off += 8
	sb.WALTail = endian.Uint64(buf[off:])
	off += 8
	sb.WALCheckpointLSN = endian.Uint64(buf[off:])
	off += 8
	copy(sb.Replication[:], buf[off:off+4])
	off += 4
	sb.CreatedAt = endian.Uint64(buf[off:])
	off += 8
	sb.SnapshotCount = endian.Uint32(buf[off:])

	return sb, nil
}
|||
|
|||
var ErrInvalidSuperblock = errors.New("blockvol: invalid superblock") |
|||
|
|||
// Validate checks that the superblock fields are internally consistent.
|
|||
func (sb *Superblock) Validate() error { |
|||
if string(sb.Magic[:]) != MagicSWBK { |
|||
return ErrNotBlockVol |
|||
} |
|||
if sb.Version != CurrentVersion { |
|||
return fmt.Errorf("%w: got %d", ErrUnsupportedVersion, sb.Version) |
|||
} |
|||
if sb.VolumeSize == 0 { |
|||
return ErrInvalidVolumeSize |
|||
} |
|||
if sb.BlockSize == 0 { |
|||
return fmt.Errorf("%w: BlockSize is 0", ErrInvalidSuperblock) |
|||
} |
|||
if sb.ExtentSize == 0 { |
|||
return fmt.Errorf("%w: ExtentSize is 0", ErrInvalidSuperblock) |
|||
} |
|||
if sb.WALSize == 0 { |
|||
return fmt.Errorf("%w: WALSize is 0", ErrInvalidSuperblock) |
|||
} |
|||
if sb.WALOffset != SuperblockSize { |
|||
return fmt.Errorf("%w: WALOffset=%d, expected %d", ErrInvalidSuperblock, sb.WALOffset, SuperblockSize) |
|||
} |
|||
if sb.VolumeSize%uint64(sb.BlockSize) != 0 { |
|||
return fmt.Errorf("%w: VolumeSize %d not aligned to BlockSize %d", ErrInvalidSuperblock, sb.VolumeSize, sb.BlockSize) |
|||
} |
|||
return nil |
|||
} |
|||
@ -0,0 +1,146 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"errors" |
|||
"testing" |
|||
) |
|||
|
|||
func TestSuperblock(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "superblock_roundtrip", run: testSuperblockRoundtrip}, |
|||
{name: "superblock_magic_check", run: testSuperblockMagicCheck}, |
|||
{name: "superblock_version_check", run: testSuperblockVersionCheck}, |
|||
{name: "superblock_zero_vol_size", run: testSuperblockZeroVolSize}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testSuperblockRoundtrip(t *testing.T) { |
|||
sb, err := NewSuperblock(100*1024*1024*1024, CreateOptions{ |
|||
ExtentSize: 64 * 1024, |
|||
BlockSize: 4096, |
|||
WALSize: 64 * 1024 * 1024, |
|||
Replication: "010", |
|||
}) |
|||
if err != nil { |
|||
t.Fatalf("NewSuperblock: %v", err) |
|||
} |
|||
sb.CreatedAt = 1700000000 |
|||
sb.SnapshotCount = 3 |
|||
sb.WALHead = 8192 |
|||
sb.WALCheckpointLSN = 42 |
|||
|
|||
var buf bytes.Buffer |
|||
n, err := sb.WriteTo(&buf) |
|||
if err != nil { |
|||
t.Fatalf("WriteTo: %v", err) |
|||
} |
|||
if n != SuperblockSize { |
|||
t.Fatalf("WriteTo wrote %d bytes, want %d", n, SuperblockSize) |
|||
} |
|||
|
|||
got, err := ReadSuperblock(&buf) |
|||
if err != nil { |
|||
t.Fatalf("ReadSuperblock: %v", err) |
|||
} |
|||
|
|||
if string(got.Magic[:]) != MagicSWBK { |
|||
t.Errorf("Magic = %q, want %q", got.Magic, MagicSWBK) |
|||
} |
|||
if got.Version != CurrentVersion { |
|||
t.Errorf("Version = %d, want %d", got.Version, CurrentVersion) |
|||
} |
|||
if got.UUID != sb.UUID { |
|||
t.Errorf("UUID mismatch") |
|||
} |
|||
if got.VolumeSize != sb.VolumeSize { |
|||
t.Errorf("VolumeSize = %d, want %d", got.VolumeSize, sb.VolumeSize) |
|||
} |
|||
if got.ExtentSize != sb.ExtentSize { |
|||
t.Errorf("ExtentSize = %d, want %d", got.ExtentSize, sb.ExtentSize) |
|||
} |
|||
if got.BlockSize != sb.BlockSize { |
|||
t.Errorf("BlockSize = %d, want %d", got.BlockSize, sb.BlockSize) |
|||
} |
|||
if got.WALOffset != sb.WALOffset { |
|||
t.Errorf("WALOffset = %d, want %d", got.WALOffset, sb.WALOffset) |
|||
} |
|||
if got.WALSize != sb.WALSize { |
|||
t.Errorf("WALSize = %d, want %d", got.WALSize, sb.WALSize) |
|||
} |
|||
if got.WALHead != sb.WALHead { |
|||
t.Errorf("WALHead = %d, want %d", got.WALHead, sb.WALHead) |
|||
} |
|||
if got.WALCheckpointLSN != sb.WALCheckpointLSN { |
|||
t.Errorf("WALCheckpointLSN = %d, want %d", got.WALCheckpointLSN, sb.WALCheckpointLSN) |
|||
} |
|||
if got.Replication != sb.Replication { |
|||
t.Errorf("Replication = %q, want %q", got.Replication, sb.Replication) |
|||
} |
|||
if got.CreatedAt != sb.CreatedAt { |
|||
t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, sb.CreatedAt) |
|||
} |
|||
if got.SnapshotCount != sb.SnapshotCount { |
|||
t.Errorf("SnapshotCount = %d, want %d", got.SnapshotCount, sb.SnapshotCount) |
|||
} |
|||
} |
|||
|
|||
func testSuperblockMagicCheck(t *testing.T) { |
|||
// Write a valid superblock, then corrupt the magic bytes.
|
|||
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
|||
var buf bytes.Buffer |
|||
sb.WriteTo(&buf) |
|||
|
|||
data := buf.Bytes() |
|||
copy(data[0:4], "XXXX") // corrupt magic
|
|||
|
|||
_, err := ReadSuperblock(bytes.NewReader(data)) |
|||
if !errors.Is(err, ErrNotBlockVol) { |
|||
t.Errorf("expected ErrNotBlockVol, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testSuperblockVersionCheck(t *testing.T) { |
|||
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
|||
var buf bytes.Buffer |
|||
sb.WriteTo(&buf) |
|||
|
|||
data := buf.Bytes() |
|||
// Version is at offset 4, uint16 little-endian.
|
|||
binary.LittleEndian.PutUint16(data[4:], 99) |
|||
|
|||
_, err := ReadSuperblock(bytes.NewReader(data)) |
|||
if !errors.Is(err, ErrUnsupportedVersion) { |
|||
t.Errorf("expected ErrUnsupportedVersion, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testSuperblockZeroVolSize(t *testing.T) { |
|||
_, err := NewSuperblock(0, CreateOptions{}) |
|||
if !errors.Is(err, ErrInvalidVolumeSize) { |
|||
t.Errorf("NewSuperblock(0): expected ErrInvalidVolumeSize, got %v", err) |
|||
} |
|||
|
|||
// Also test reading a superblock with zero volume size.
|
|||
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
|||
var buf bytes.Buffer |
|||
sb.WriteTo(&buf) |
|||
|
|||
data := buf.Bytes() |
|||
// VolumeSize is at offset 4+2+2+16 = 24, uint64 little-endian.
|
|||
binary.LittleEndian.PutUint64(data[24:], 0) |
|||
|
|||
_, err = ReadSuperblock(bytes.NewReader(data)) |
|||
if !errors.Is(err, ErrInvalidVolumeSize) { |
|||
t.Errorf("ReadSuperblock(vol_size=0): expected ErrInvalidVolumeSize, got %v", err) |
|||
} |
|||
} |
|||
@ -0,0 +1,153 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"hash/crc32" |
|||
) |
|||
|
|||
// WAL entry type codes stored in the Type byte of each serialized entry.
const (
	EntryTypeWrite   = 0x01 // block write; carries a data payload
	EntryTypeTrim    = 0x02 // discard; Length is the trim extent, no payload
	EntryTypeBarrier = 0x03 // ordering marker; no length, no payload
	EntryTypePadding = 0xFF // filler written when an entry would wrap the WAL

	// walEntryHeaderSize is the fixed portion: LSN(8) + Epoch(8) + Type(1) +
	// Flags(1) + LBA(8) + Length(4) + CRC32(4) + EntrySize(4) = 38 bytes.
	walEntryHeaderSize = 38
)

// Sentinel errors for WAL entry encode/decode; match with errors.Is.
var (
	ErrCRCMismatch    = errors.New("blockvol: CRC mismatch")
	ErrInvalidEntry   = errors.New("blockvol: invalid WAL entry")
	ErrEntryTruncated = errors.New("blockvol: WAL entry truncated")
)

// WALEntry is a variable-size record in the WAL region.
// CRC32 and EntrySize are computed by Encode and verified by DecodeWALEntry;
// callers normally leave them zero when constructing an entry.
type WALEntry struct {
	LSN       uint64
	Epoch     uint64 // writer's epoch (Phase 1: always 0)
	Type      uint8  // EntryTypeWrite, EntryTypeTrim, EntryTypeBarrier
	Flags     uint8
	LBA       uint64 // in blocks
	Length    uint32 // data length in bytes (WRITE: data size, TRIM: trim size, BARRIER: 0)
	Data      []byte // present only for WRITE
	CRC32     uint32 // covers LSN through Data
	EntrySize uint32 // total serialized size
}
|||
|
|||
// Encode serializes the entry into a byte slice, computing CRC over all fields.
//
// Layout: fixed header fields (LSN, Epoch, Type, Flags, LBA, Length),
// then the optional Data payload, then a footer of CRC32 + EntrySize.
// The CRC covers everything before the footer. On success, e.CRC32 and
// e.EntrySize are updated in place to the computed values.
//
// Per-type validation: WRITE requires Data whose length equals Length;
// TRIM carries Length but must have no payload; BARRIER must have neither.
// NOTE(review): there is no default case, so unrecognized Type values are
// encoded without validation — confirm whether that is intentional.
func (e *WALEntry) Encode() ([]byte, error) {
	switch e.Type {
	case EntryTypeWrite:
		if len(e.Data) == 0 {
			return nil, fmt.Errorf("%w: WRITE entry with no data", ErrInvalidEntry)
		}
		if uint32(len(e.Data)) != e.Length {
			return nil, fmt.Errorf("%w: data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length)
		}
	case EntryTypeTrim:
		if len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: TRIM entry must have no data payload", ErrInvalidEntry)
		}
		// TRIM carries Length (trim size in bytes) but no Data payload.
	case EntryTypeBarrier:
		if e.Length != 0 || len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: BARRIER entry must have no data", ErrInvalidEntry)
		}
	}

	totalSize := uint32(walEntryHeaderSize + len(e.Data))
	buf := make([]byte, totalSize)

	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], e.LSN)
	off += 8
	le.PutUint64(buf[off:], e.Epoch)
	off += 8
	buf[off] = e.Type
	off++
	buf[off] = e.Flags
	off++
	le.PutUint64(buf[off:], e.LBA)
	off += 8
	le.PutUint32(buf[off:], e.Length)
	off += 4

	if len(e.Data) > 0 {
		copy(buf[off:], e.Data)
		off += len(e.Data)
	}

	// CRC covers everything from start through Data (offset 0 to off).
	checksum := crc32.ChecksumIEEE(buf[:off])
	le.PutUint32(buf[off:], checksum)
	off += 4
	le.PutUint32(buf[off:], totalSize)

	e.CRC32 = checksum
	e.EntrySize = totalSize
	return buf, nil
}
|||
|
|||
// DecodeWALEntry deserializes a WAL entry from buf, validating CRC.
//
// buf may be longer than the entry (e.g. a slice of the whole WAL region);
// only the leading entry is decoded. Returns ErrEntryTruncated when buf is
// too short, ErrCRCMismatch on checksum failure, and ErrInvalidEntry when
// the stored EntrySize disagrees with the decoded layout (all wrapped).
// The payload is copied out of buf, so the result does not alias it.
func DecodeWALEntry(buf []byte) (WALEntry, error) {
	if len(buf) < walEntryHeaderSize {
		return WALEntry{}, fmt.Errorf("%w: need %d bytes, have %d", ErrEntryTruncated, walEntryHeaderSize, len(buf))
	}

	le := binary.LittleEndian
	var e WALEntry
	off := 0
	e.LSN = le.Uint64(buf[off:])
	off += 8
	e.Epoch = le.Uint64(buf[off:])
	off += 8
	e.Type = buf[off]
	off++
	e.Flags = buf[off]
	off++
	e.LBA = le.Uint64(buf[off:])
	off += 8
	e.Length = le.Uint32(buf[off:])
	off += 4

	// For WRITE entries, Length is the data payload size.
	// For TRIM entries, Length is the trim extent in bytes (no data payload).
	// For BARRIER/PADDING, Length is 0 (or padding size).
	var dataLen int
	if e.Type == EntryTypeWrite || e.Type == EntryTypePadding {
		dataLen = int(e.Length)
	}

	dataEnd := off + dataLen
	if dataEnd+8 > len(buf) { // +8 for CRC32 + EntrySize
		return WALEntry{}, fmt.Errorf("%w: need %d bytes for data+footer, have %d", ErrEntryTruncated, dataEnd+8, len(buf))
	}

	if dataLen > 0 {
		e.Data = make([]byte, dataLen)
		copy(e.Data, buf[off:dataEnd])
	}
	off = dataEnd

	e.CRC32 = le.Uint32(buf[off:])
	off += 4
	e.EntrySize = le.Uint32(buf[off:])

	// Verify CRC: covers LSN through Data.
	expected := crc32.ChecksumIEEE(buf[:dataEnd])
	if e.CRC32 != expected {
		return WALEntry{}, fmt.Errorf("%w: stored=%08x computed=%08x", ErrCRCMismatch, e.CRC32, expected)
	}

	// Verify EntrySize matches actual layout.
	expectedSize := uint32(walEntryHeaderSize) + uint32(dataLen)
	if e.EntrySize != expectedSize {
		return WALEntry{}, fmt.Errorf("%w: EntrySize=%d, expected=%d", ErrInvalidEntry, e.EntrySize, expectedSize)
	}

	return e, nil
}
|||
@ -0,0 +1,284 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"errors" |
|||
"testing" |
|||
) |
|||
|
|||
func TestWALEntry(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "wal_entry_roundtrip", run: testWALEntryRoundtrip}, |
|||
{name: "wal_entry_trim_no_data", run: testWALEntryTrimNoData}, |
|||
{name: "wal_entry_barrier", run: testWALEntryBarrier}, |
|||
{name: "wal_entry_crc_valid", run: testWALEntryCRCValid}, |
|||
{name: "wal_entry_crc_corrupt", run: testWALEntryCRCCorrupt}, |
|||
{name: "wal_entry_max_size", run: testWALEntryMaxSize}, |
|||
{name: "wal_entry_zero_length", run: testWALEntryZeroLength}, |
|||
{name: "wal_entry_trim_rejects_data", run: testWALEntryTrimRejectsData}, |
|||
{name: "wal_entry_barrier_rejects_data", run: testWALEntryBarrierRejectsData}, |
|||
{name: "wal_entry_bad_entry_size", run: testWALEntryBadEntrySize}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryRoundtrip(t *testing.T) { |
|||
data := make([]byte, 4096) |
|||
for i := range data { |
|||
data[i] = byte(i % 251) // deterministic pattern
|
|||
} |
|||
|
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Epoch: 0, |
|||
Type: EntryTypeWrite, |
|||
Flags: 0, |
|||
LBA: 100, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
got, err := DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("DecodeWALEntry: %v", err) |
|||
} |
|||
|
|||
if got.LSN != e.LSN { |
|||
t.Errorf("LSN = %d, want %d", got.LSN, e.LSN) |
|||
} |
|||
if got.Epoch != e.Epoch { |
|||
t.Errorf("Epoch = %d, want %d", got.Epoch, e.Epoch) |
|||
} |
|||
if got.Type != e.Type { |
|||
t.Errorf("Type = %d, want %d", got.Type, e.Type) |
|||
} |
|||
if got.LBA != e.LBA { |
|||
t.Errorf("LBA = %d, want %d", got.LBA, e.LBA) |
|||
} |
|||
if got.Length != e.Length { |
|||
t.Errorf("Length = %d, want %d", got.Length, e.Length) |
|||
} |
|||
if !bytes.Equal(got.Data, e.Data) { |
|||
t.Errorf("Data mismatch") |
|||
} |
|||
if got.EntrySize != uint32(walEntryHeaderSize+len(data)) { |
|||
t.Errorf("EntrySize = %d, want %d", got.EntrySize, walEntryHeaderSize+len(data)) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryTrimNoData(t *testing.T) { |
|||
// TRIM carries Length (trim size in bytes) but no Data payload.
|
|||
e := WALEntry{ |
|||
LSN: 5, |
|||
Epoch: 0, |
|||
Type: EntryTypeTrim, |
|||
LBA: 200, |
|||
Length: 4096, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
got, err := DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("DecodeWALEntry: %v", err) |
|||
} |
|||
|
|||
if got.Type != EntryTypeTrim { |
|||
t.Errorf("Type = %d, want %d", got.Type, EntryTypeTrim) |
|||
} |
|||
if got.Length != 4096 { |
|||
t.Errorf("Length = %d, want 4096", got.Length) |
|||
} |
|||
if len(got.Data) != 0 { |
|||
t.Errorf("Data should be empty, got %d bytes", len(got.Data)) |
|||
} |
|||
// TRIM with Length but no Data: EntrySize = header only.
|
|||
if got.EntrySize != uint32(walEntryHeaderSize) { |
|||
t.Errorf("EntrySize = %d, want %d (header only)", got.EntrySize, walEntryHeaderSize) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryBarrier(t *testing.T) { |
|||
e := WALEntry{ |
|||
LSN: 10, |
|||
Epoch: 0, |
|||
Type: EntryTypeBarrier, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
got, err := DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("DecodeWALEntry: %v", err) |
|||
} |
|||
|
|||
if got.Type != EntryTypeBarrier { |
|||
t.Errorf("Type = %d, want %d", got.Type, EntryTypeBarrier) |
|||
} |
|||
if len(got.Data) != 0 { |
|||
t.Errorf("Barrier should have no data, got %d bytes", len(got.Data)) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryCRCValid(t *testing.T) { |
|||
data := []byte("hello blockvol WAL") |
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
// Decode should succeed with valid CRC.
|
|||
_, err = DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("valid entry should decode without error, got: %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryCRCCorrupt(t *testing.T) { |
|||
data := []byte("hello blockvol WAL") |
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
// Flip one bit in the data region.
|
|||
buf[walEntryHeaderSize-8] ^= 0x01 // flip bit in data area (before CRC/EntrySize footer)
|
|||
|
|||
_, err = DecodeWALEntry(buf) |
|||
if !errors.Is(err, ErrCRCMismatch) { |
|||
t.Errorf("expected ErrCRCMismatch, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryMaxSize(t *testing.T) { |
|||
data := make([]byte, 64*1024) // 64KB
|
|||
for i := range data { |
|||
data[i] = byte(i % 256) |
|||
} |
|||
|
|||
e := WALEntry{ |
|||
LSN: 99, |
|||
Epoch: 0, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode 64KB entry: %v", err) |
|||
} |
|||
|
|||
got, err := DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("DecodeWALEntry: %v", err) |
|||
} |
|||
|
|||
if !bytes.Equal(got.Data, data) { |
|||
t.Errorf("64KB data mismatch after roundtrip") |
|||
} |
|||
} |
|||
|
|||
func testWALEntryZeroLength(t *testing.T) { |
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: 0, |
|||
Data: nil, |
|||
} |
|||
|
|||
_, err := e.Encode() |
|||
if !errors.Is(err, ErrInvalidEntry) { |
|||
t.Errorf("WRITE with zero data: expected ErrInvalidEntry, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryTrimRejectsData(t *testing.T) { |
|||
// TRIM allows Length (trim extent) but rejects Data payload.
|
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeTrim, |
|||
LBA: 0, |
|||
Length: 4096, |
|||
Data: []byte("bad!"), |
|||
} |
|||
_, err := e.Encode() |
|||
if !errors.Is(err, ErrInvalidEntry) { |
|||
t.Errorf("TRIM with data payload: expected ErrInvalidEntry, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryBarrierRejectsData(t *testing.T) { |
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeBarrier, |
|||
Length: 1, |
|||
Data: []byte("x"), |
|||
} |
|||
_, err := e.Encode() |
|||
if !errors.Is(err, ErrInvalidEntry) { |
|||
t.Errorf("BARRIER with data: expected ErrInvalidEntry, got %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALEntryBadEntrySize(t *testing.T) { |
|||
data := []byte("test data for entry size validation") |
|||
e := WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: uint32(len(data)), |
|||
Data: data, |
|||
} |
|||
|
|||
buf, err := e.Encode() |
|||
if err != nil { |
|||
t.Fatalf("Encode: %v", err) |
|||
} |
|||
|
|||
// Corrupt EntrySize (last 4 bytes of buf).
|
|||
le := binary.LittleEndian |
|||
le.PutUint32(buf[len(buf)-4:], 9999) |
|||
|
|||
_, err = DecodeWALEntry(buf) |
|||
if !errors.Is(err, ErrInvalidEntry) { |
|||
t.Errorf("bad EntrySize: expected ErrInvalidEntry, got %v", err) |
|||
} |
|||
} |
|||
@ -0,0 +1,192 @@ |
|||
package blockvol |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"hash/crc32" |
|||
"os" |
|||
"sync" |
|||
) |
|||
|
|||
var (
	// ErrWALFull is returned by Append when the WAL region has no room for
	// the entry (plus any wrap padding). Match with errors.Is.
	ErrWALFull = errors.New("blockvol: WAL region full")
)

// WALWriter appends entries to the circular WAL region of a blockvol file.
//
// It uses logical (monotonically increasing) head and tail counters to track
// used space. Physical position = logical % walSize. This eliminates the
// classic circular buffer ambiguity where head==tail could mean empty or full.
// Used space = logicalHead - logicalTail. Free space = walSize - used.
//
// All exported methods take mu, so a WALWriter is safe for concurrent use;
// it must not be copied after first use (contains a sync.Mutex).
type WALWriter struct {
	mu          sync.Mutex
	fd          *os.File
	walOffset   uint64 // absolute file offset where WAL region starts
	walSize     uint64 // size of the WAL region in bytes
	logicalHead uint64 // monotonically increasing write position
	logicalTail uint64 // monotonically increasing flush position
}
|||
|
|||
// NewWALWriter creates a WAL writer for the given file.
// head and tail are physical positions relative to WAL region start.
// For a fresh WAL, both are 0.
//
// NOTE(review): the doc above says "physical positions" but the values seed
// the *logical* counters directly. On a fresh WAL (both 0) the distinction
// vanishes; confirm what reopen paths pass here after the log has wrapped.
func NewWALWriter(fd *os.File, walOffset, walSize, head, tail uint64) *WALWriter {
	return &WALWriter{
		fd:          fd,
		walOffset:   walOffset,
		walSize:     walSize,
		logicalHead: head, // on fresh WAL, physical == logical (both start at 0)
		logicalTail: tail,
	}
}
|||
|
|||
// physicalPos converts a logical position to a physical WAL offset.
// Callers must hold w.mu or pass an already-consistent snapshot value.
func (w *WALWriter) physicalPos(logical uint64) uint64 {
	return logical % w.walSize
}

// used returns the number of bytes occupied in the WAL.
// Valid because both counters are monotonic; caller must hold w.mu.
func (w *WALWriter) used() uint64 {
	return w.logicalHead - w.logicalTail
}
|||
|
|||
// Append writes a serialized WAL entry to the circular WAL region.
// Returns the physical WAL-relative offset where the entry was written.
// If the entry doesn't fit in the remaining space before the region end,
// a padding entry is written and the real entry starts at physical offset 0.
//
// Returns ErrWALFull (possibly wrapped) when the entry — including any wrap
// padding — does not fit in the free space, or when the entry alone exceeds
// the whole WAL region. Note: Append does not fsync; durability is the
// caller's responsibility.
func (w *WALWriter) Append(entry *WALEntry) (walRelOffset uint64, err error) {
	// Encode outside the lock; only the space accounting and pwrite need mu.
	buf, err := entry.Encode()
	if err != nil {
		return 0, fmt.Errorf("WALWriter.Append: encode: %w", err)
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	entryLen := uint64(len(buf))
	if entryLen > w.walSize {
		return 0, fmt.Errorf("%w: entry size %d exceeds WAL size %d", ErrWALFull, entryLen, w.walSize)
	}

	physHead := w.physicalPos(w.logicalHead)
	remaining := w.walSize - physHead

	if remaining < entryLen {
		// Not enough room at end of region -- write padding and wrap.
		// Padding consumes 'remaining' bytes logically.
		if w.used()+remaining+entryLen > w.walSize {
			return 0, ErrWALFull
		}
		if err := w.writePadding(remaining, physHead); err != nil {
			return 0, fmt.Errorf("WALWriter.Append: padding: %w", err)
		}
		w.logicalHead += remaining
		physHead = 0
	}

	// Check if there's enough free space for the entry.
	if w.used()+entryLen > w.walSize {
		return 0, ErrWALFull
	}

	// Positional write: never moves the file offset, safe alongside other
	// readers/writers of the same fd.
	absOffset := int64(w.walOffset + physHead)
	if _, err := w.fd.WriteAt(buf, absOffset); err != nil {
		return 0, fmt.Errorf("WALWriter.Append: pwrite at offset %d: %w", absOffset, err)
	}

	writeOffset := physHead
	w.logicalHead += entryLen
	return writeOffset, nil
}
|||
|
|||
// writePadding writes a padding entry at the given physical position.
//
// When size is at least one entry header, a well-formed EntryTypePadding
// record is written (Length = size - header, CRC over header+padding bytes,
// EntrySize = size) so WAL scans can decode and skip it. When size is
// smaller than a header, the region is simply zeroed.
func (w *WALWriter) writePadding(size uint64, physPos uint64) error {
	if size < walEntryHeaderSize {
		// Too small for a proper entry header -- zero it out.
		buf := make([]byte, size)
		absOffset := int64(w.walOffset + physPos)
		_, err := w.fd.WriteAt(buf, absOffset)
		return err
	}

	// Lay out the padding record with the same byte layout as Encode:
	// fixed header fields, zero "data", then CRC32 + EntrySize footer.
	buf := make([]byte, size)
	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], 0) // LSN=0
	off += 8
	le.PutUint64(buf[off:], 0) // Epoch=0
	off += 8
	buf[off] = EntryTypePadding
	off++
	buf[off] = 0 // Flags
	off++
	le.PutUint64(buf[off:], 0) // LBA=0
	off += 8
	paddingDataLen := uint32(size) - uint32(walEntryHeaderSize)
	le.PutUint32(buf[off:], paddingDataLen)
	off += 4
	dataEnd := off + int(paddingDataLen)

	// CRC covers header + padding bytes, matching DecodeWALEntry's check.
	crc := crc32.ChecksumIEEE(buf[:dataEnd])
	le.PutUint32(buf[dataEnd:], crc)
	le.PutUint32(buf[dataEnd+4:], uint32(size))

	absOffset := int64(w.walOffset + physPos)
	_, err := w.fd.WriteAt(buf, absOffset)
	return err
}
|||
|
|||
// AdvanceTail moves the tail forward, freeing WAL space.
|
|||
// Called by the flusher after entries have been written to the extent region.
|
|||
// newTail is a physical position; it is converted to a logical advance.
|
|||
func (w *WALWriter) AdvanceTail(newTail uint64) { |
|||
w.mu.Lock() |
|||
physTail := w.physicalPos(w.logicalTail) |
|||
var advance uint64 |
|||
if newTail >= physTail { |
|||
advance = newTail - physTail |
|||
} else { |
|||
// Tail wrapped around.
|
|||
advance = w.walSize - physTail + newTail |
|||
} |
|||
w.logicalTail += advance |
|||
w.mu.Unlock() |
|||
} |
|||
|
|||
// Head returns the current physical head position (relative to WAL start).
|
|||
func (w *WALWriter) Head() uint64 { |
|||
w.mu.Lock() |
|||
h := w.physicalPos(w.logicalHead) |
|||
w.mu.Unlock() |
|||
return h |
|||
} |
|||
|
|||
// Tail returns the current physical tail position (relative to WAL start).
|
|||
func (w *WALWriter) Tail() uint64 { |
|||
w.mu.Lock() |
|||
t := w.physicalPos(w.logicalTail) |
|||
w.mu.Unlock() |
|||
return t |
|||
} |
|||
|
|||
// LogicalHead returns the logical (monotonically increasing) head position.
|
|||
func (w *WALWriter) LogicalHead() uint64 { |
|||
w.mu.Lock() |
|||
h := w.logicalHead |
|||
w.mu.Unlock() |
|||
return h |
|||
} |
|||
|
|||
// LogicalTail returns the logical (monotonically increasing) tail position.
|
|||
func (w *WALWriter) LogicalTail() uint64 { |
|||
w.mu.Lock() |
|||
t := w.logicalTail |
|||
w.mu.Unlock() |
|||
return t |
|||
} |
|||
|
|||
// Sync fsyncs the underlying file descriptor, making previously appended
// WAL entries durable on stable storage. Called without holding w.mu.
func (w *WALWriter) Sync() error {
	return w.fd.Sync()
}
|||
@ -0,0 +1,244 @@ |
|||
package blockvol |
|||
|
|||
import (
	"errors"
	"os"
	"path/filepath"
	"testing"
)
|||
|
|||
func TestWALWriter(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
run func(t *testing.T) |
|||
}{ |
|||
{name: "wal_writer_append_read_back", run: testWALWriterAppendReadBack}, |
|||
{name: "wal_writer_multiple_entries", run: testWALWriterMultipleEntries}, |
|||
{name: "wal_writer_wrap_around", run: testWALWriterWrapAround}, |
|||
{name: "wal_writer_full", run: testWALWriterFull}, |
|||
{name: "wal_writer_advance_tail_frees_space", run: testWALWriterAdvanceTailFreesSpace}, |
|||
{name: "wal_writer_fill_no_flusher", run: testWALWriterFillNoFlusher}, |
|||
} |
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tt.run(t) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// createTestWAL creates a temp file large enough to hold a WAL region of
// walSize bytes starting at walOffset. It returns the open file and a
// cleanup function that closes it; the file itself lives in t.TempDir()
// and is removed automatically when the test finishes.
func createTestWAL(t *testing.T, walOffset, walSize uint64) (*os.File, func()) {
	t.Helper()
	path := filepath.Join(t.TempDir(), "test.blockvol")
	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("create temp file: %v", err)
	}
	// The file must span the superblock region plus the full WAL region.
	if err := fd.Truncate(int64(walOffset + walSize)); err != nil {
		fd.Close()
		t.Fatalf("truncate: %v", err)
	}
	return fd, func() { fd.Close() }
}
|||
|
|||
func testWALWriterAppendReadBack(t *testing.T) { |
|||
walOffset := uint64(SuperblockSize) |
|||
walSize := uint64(64 * 1024) // 64KB WAL
|
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
data := []byte("hello WAL writer test!") |
|||
// Pad to block size for a realistic entry.
|
|||
padded := make([]byte, 4096) |
|||
copy(padded, data) |
|||
|
|||
entry := &WALEntry{ |
|||
LSN: 1, |
|||
Type: EntryTypeWrite, |
|||
LBA: 0, |
|||
Length: uint32(len(padded)), |
|||
Data: padded, |
|||
} |
|||
|
|||
off, err := w.Append(entry) |
|||
if err != nil { |
|||
t.Fatalf("Append: %v", err) |
|||
} |
|||
if off != 0 { |
|||
t.Errorf("first entry offset = %d, want 0", off) |
|||
} |
|||
|
|||
// Read back from file and decode.
|
|||
buf := make([]byte, entry.EntrySize) |
|||
if _, err := fd.ReadAt(buf, int64(walOffset+off)); err != nil { |
|||
t.Fatalf("ReadAt: %v", err) |
|||
} |
|||
decoded, err := DecodeWALEntry(buf) |
|||
if err != nil { |
|||
t.Fatalf("DecodeWALEntry: %v", err) |
|||
} |
|||
if decoded.LSN != 1 { |
|||
t.Errorf("decoded LSN = %d, want 1", decoded.LSN) |
|||
} |
|||
} |
|||
|
|||
func testWALWriterMultipleEntries(t *testing.T) { |
|||
walOffset := uint64(SuperblockSize) |
|||
walSize := uint64(64 * 1024) |
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
for i := uint64(1); i <= 5; i++ { |
|||
entry := &WALEntry{ |
|||
LSN: i, |
|||
Type: EntryTypeWrite, |
|||
LBA: i * 10, |
|||
Length: 4096, |
|||
Data: make([]byte, 4096), |
|||
} |
|||
_, err := w.Append(entry) |
|||
if err != nil { |
|||
t.Fatalf("Append entry %d: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
expectedHead := uint64(5 * (walEntryHeaderSize + 4096)) |
|||
if w.Head() != expectedHead { |
|||
t.Errorf("head = %d, want %d", w.Head(), expectedHead) |
|||
} |
|||
} |
|||
|
|||
func testWALWriterWrapAround(t *testing.T) { |
|||
walOffset := uint64(SuperblockSize) |
|||
// Small WAL: 2.5x entry size. After writing 2 entries (head at 2*entrySize),
|
|||
// remaining (0.5*entrySize) is too small for entry 3, so it wraps to offset 0.
|
|||
// Tail must be advanced past both entries so there's free space after wrap.
|
|||
entrySize := uint64(walEntryHeaderSize + 4096) |
|||
walSize := entrySize*2 + entrySize/2 |
|||
|
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
// Write 2 entries to fill most of the WAL.
|
|||
for i := uint64(1); i <= 2; i++ { |
|||
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)} |
|||
if _, err := w.Append(entry); err != nil { |
|||
t.Fatalf("Append entry %d: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
// Advance tail past both entries (simulates flusher flushed them).
|
|||
w.AdvanceTail(entrySize * 2) |
|||
|
|||
// Write a 3rd entry -- should wrap around.
|
|||
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 3, Length: 4096, Data: make([]byte, 4096)} |
|||
off, err := w.Append(entry3) |
|||
if err != nil { |
|||
t.Fatalf("Append entry 3 (wrap): %v", err) |
|||
} |
|||
|
|||
// After wrap, entry should be written at offset 0.
|
|||
if off != 0 { |
|||
t.Errorf("wrapped entry offset = %d, want 0", off) |
|||
} |
|||
} |
|||
|
|||
func testWALWriterFull(t *testing.T) { |
|||
walOffset := uint64(SuperblockSize) |
|||
entrySize := uint64(walEntryHeaderSize + 4096) |
|||
walSize := entrySize * 2 // fits exactly 2 entries
|
|||
|
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
// Fill the WAL with 2 entries (exact fit).
|
|||
for i := uint64(1); i <= 2; i++ { |
|||
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)} |
|||
if _, err := w.Append(entry); err != nil { |
|||
t.Fatalf("Append entry %d: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
// Third entry should fail -- tail hasn't moved, so no free space.
|
|||
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)} |
|||
_, err := w.Append(entry3) |
|||
if err == nil { |
|||
t.Fatal("expected ErrWALFull when WAL is full") |
|||
} |
|||
} |
|||
|
|||
func testWALWriterAdvanceTailFreesSpace(t *testing.T) { |
|||
walOffset := uint64(SuperblockSize) |
|||
entrySize := uint64(walEntryHeaderSize + 4096) |
|||
walSize := entrySize * 2 |
|||
|
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
// Fill the WAL with 2 entries.
|
|||
for i := uint64(1); i <= 2; i++ { |
|||
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)} |
|||
if _, err := w.Append(entry); err != nil { |
|||
t.Fatalf("Append entry %d: %v", i, err) |
|||
} |
|||
} |
|||
|
|||
// WAL is full. Third entry should fail.
|
|||
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)} |
|||
if _, err := w.Append(entry3); err == nil { |
|||
t.Fatal("expected ErrWALFull before AdvanceTail") |
|||
} |
|||
|
|||
// Advance tail to free space for 1 entry.
|
|||
w.AdvanceTail(entrySize) |
|||
|
|||
// Now entry 3 should succeed.
|
|||
if _, err := w.Append(entry3); err != nil { |
|||
t.Fatalf("Append after AdvanceTail: %v", err) |
|||
} |
|||
} |
|||
|
|||
func testWALWriterFillNoFlusher(t *testing.T) { |
|||
// QA-001 regression: fill WAL without flusher (tail stays at 0).
|
|||
// After wrap, head=0 and tail=0 must NOT be treated as "empty".
|
|||
walOffset := uint64(SuperblockSize) |
|||
entrySize := uint64(walEntryHeaderSize + 4096) |
|||
walSize := entrySize * 10 // room for ~10 entries
|
|||
|
|||
fd, cleanup := createTestWAL(t, walOffset, walSize) |
|||
defer cleanup() |
|||
|
|||
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
|||
|
|||
// Fill the WAL completely -- tail never moves (no flusher).
|
|||
written := 0 |
|||
for i := uint64(1); ; i++ { |
|||
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)} |
|||
_, err := w.Append(entry) |
|||
if err != nil { |
|||
// Should eventually get ErrWALFull, NOT wrap and overwrite.
|
|||
break |
|||
} |
|||
written++ |
|||
if written > 20 { |
|||
t.Fatalf("wrote %d entries to a 10-entry WAL without ErrWALFull -- wrap overwrote live entries (QA-001)", written) |
|||
} |
|||
} |
|||
|
|||
if written == 0 { |
|||
t.Fatal("should have written at least 1 entry") |
|||
} |
|||
t.Logf("correctly wrote %d entries then got ErrWALFull", written) |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue