Browse Source
feat: add BlockVol engine and iSCSI target (Phase 1 + Phase 2)

Phase 1: Extent-mapped block storage engine with WAL, crash recovery, dirty map, flusher, and group commit. 174 tests, zero SeaweedFS imports. Phase 2: Pure Go iSCSI target (RFC 7143) with PDU codec, login negotiation, SendTargets discovery, 12 SCSI opcodes, Data-In/Out/R2T sequencing, session management, and standalone iscsi-target binary. 164 tests. IQN->BlockDevice binding via DeviceLookup interface. Total: 338 tests, 14.6K lines.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Branch: feature/sw-block
38 changed files with 14618 additions and 0 deletions
-
371weed/storage/blockvol/blockvol.go
-
3903weed/storage/blockvol/blockvol_qa_test.go
-
685weed/storage/blockvol/blockvol_test.go
-
110weed/storage/blockvol/dirty_map.go
-
153weed/storage/blockvol/dirty_map_test.go
-
239weed/storage/blockvol/flusher.go
-
261weed/storage/blockvol/flusher_test.go
-
167weed/storage/blockvol/group_commit.go
-
287weed/storage/blockvol/group_commit_test.go
-
141weed/storage/blockvol/iscsi/cmd/iscsi-target/main.go
-
216weed/storage/blockvol/iscsi/cmd/iscsi-target/smoke-test.sh
-
207weed/storage/blockvol/iscsi/dataio.go
-
410weed/storage/blockvol/iscsi/dataio_test.go
-
68weed/storage/blockvol/iscsi/discovery.go
-
213weed/storage/blockvol/iscsi/discovery_test.go
-
412weed/storage/blockvol/iscsi/integration_test.go
-
396weed/storage/blockvol/iscsi/login.go
-
444weed/storage/blockvol/iscsi/login_test.go
-
183weed/storage/blockvol/iscsi/params.go
-
357weed/storage/blockvol/iscsi/params_test.go
-
447weed/storage/blockvol/iscsi/pdu.go
-
559weed/storage/blockvol/iscsi/pdu_test.go
-
463weed/storage/blockvol/iscsi/scsi.go
-
692weed/storage/blockvol/iscsi/scsi_test.go
-
421weed/storage/blockvol/iscsi/session.go
-
364weed/storage/blockvol/iscsi/session_test.go
-
216weed/storage/blockvol/iscsi/target.go
-
259weed/storage/blockvol/iscsi/target_test.go
-
38weed/storage/blockvol/lba.go
-
95weed/storage/blockvol/lba_test.go
-
157weed/storage/blockvol/recovery.go
-
416weed/storage/blockvol/recovery_test.go
-
249weed/storage/blockvol/superblock.go
-
146weed/storage/blockvol/superblock_test.go
-
153weed/storage/blockvol/wal_entry.go
-
284weed/storage/blockvol/wal_entry_test.go
-
192weed/storage/blockvol/wal_writer.go
-
244weed/storage/blockvol/wal_writer_test.go
@ -0,0 +1,371 @@ |
|||||
|
// Package blockvol implements a block volume storage engine with WAL,
|
||||
|
// dirty map, group commit, and crash recovery. It operates on a single
|
||||
|
// file and has no dependency on SeaweedFS internals.
|
||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"fmt" |
||||
|
"os" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// CreateOptions configures a new block volume.
//
// Only VolumeSize is mandatory; the remaining fields fall back to the
// defaults noted below when left zero.
// NOTE(review): the defaults are presumably applied by NewSuperblock —
// confirm against superblock.go.
type CreateOptions struct {
	VolumeSize  uint64 // required, logical size in bytes
	ExtentSize  uint32 // default 64KB
	BlockSize   uint32 // default 4KB
	WALSize     uint64 // default 64MB
	Replication string // default "000"
}
||||
|
|
||||
|
// BlockVol is the core block volume engine.
//
// A BlockVol owns a single backing file laid out as
// superblock | WAL region | extent region. Writes land in the WAL and are
// tracked in the dirty map; the background flusher migrates them into the
// extent region. Must not be copied (contains sync.RWMutex).
type BlockVol struct {
	mu          sync.RWMutex    // guards cross-field invariants (currently unused in this chunk — TODO confirm usage elsewhere)
	fd          *os.File        // backing file for superblock, WAL, and extents
	path        string          // file path, returned by Path()
	super       Superblock      // in-memory copy of the on-disk superblock
	wal         *WALWriter      // appends entries to the WAL region
	dirtyMap    *DirtyMap       // LBA -> WAL offset for not-yet-flushed blocks
	groupCommit *GroupCommitter // batches fsyncs for SyncCache callers
	flusher     *Flusher        // background WAL -> extent migration
	nextLSN     atomic.Uint64   // next log sequence number to hand out
	healthy     atomic.Bool     // cleared when the group committer degrades
}
||||
|
|
||||
|
// CreateBlockVol creates a new block volume file at path.
//
// The file is created exclusively (O_EXCL), sized to hold the superblock,
// the WAL region, and the full extent region, and the superblock is written
// and fsynced before the engine's background goroutines (group committer,
// flusher) are started. On any failure the partially created file is removed.
func CreateBlockVol(path string, opts CreateOptions) (*BlockVol, error) {
	if opts.VolumeSize == 0 {
		return nil, ErrInvalidVolumeSize
	}

	sb, err := NewSuperblock(opts.VolumeSize, opts)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create superblock: %w", err)
	}
	sb.CreatedAt = uint64(time.Now().Unix())

	// Extent region starts after superblock + WAL.
	extentStart := sb.WALOffset + sb.WALSize
	totalFileSize := extentStart + opts.VolumeSize

	// O_EXCL: refuse to clobber an existing volume file.
	fd, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: create file: %w", err)
	}

	// Pre-size the file so extent-region reads of unwritten blocks succeed
	// (and return zeros).
	if err := fd.Truncate(int64(totalFileSize)); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: truncate: %w", err)
	}

	if _, err := sb.WriteTo(fd); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: write superblock: %w", err)
	}

	// Make the freshly written superblock durable before declaring success.
	if err := fd.Sync(); err != nil {
		fd.Close()
		os.Remove(path)
		return nil, fmt.Errorf("blockvol: sync: %w", err)
	}

	dm := NewDirtyMap()
	// Fresh volume: WAL head and tail both start at 0.
	wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, 0, 0)
	v := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      wal,
		dirtyMap: dm,
	}
	// LSN 0 is reserved; first write gets LSN 1.
	v.nextLSN.Store(1)
	v.healthy.Store(true)
	v.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { v.healthy.Store(false) },
	})
	go v.groupCommit.Run()
	v.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &v.super,
		WAL:      wal,
		DirtyMap: dm,
	})
	go v.flusher.Run()
	return v, nil
}
||||
|
|
||||
|
// OpenBlockVol opens an existing block volume file and runs crash recovery.
//
// Recovery order matters: the superblock is read and validated first, then
// RecoverWAL replays WAL entries into a fresh dirty map, and only then are
// the background goroutines started. nextLSN is advanced past both the
// checkpoint LSN and the highest LSN seen during replay so new writes never
// reuse a sequence number.
func OpenBlockVol(path string) (*BlockVol, error) {
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		return nil, fmt.Errorf("blockvol: open file: %w", err)
	}

	sb, err := ReadSuperblock(fd)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: read superblock: %w", err)
	}
	if err := sb.Validate(); err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: validate superblock: %w", err)
	}

	dirtyMap := NewDirtyMap()

	// Run WAL recovery: replay entries from tail to head.
	result, err := RecoverWAL(fd, &sb, dirtyMap)
	if err != nil {
		fd.Close()
		return nil, fmt.Errorf("blockvol: recovery: %w", err)
	}

	// Next LSN must exceed both the checkpoint and anything replayed.
	nextLSN := sb.WALCheckpointLSN + 1
	if result.HighestLSN >= nextLSN {
		nextLSN = result.HighestLSN + 1
	}

	// Resume the WAL writer at the persisted head/tail positions.
	wal := NewWALWriter(fd, sb.WALOffset, sb.WALSize, sb.WALHead, sb.WALTail)
	v := &BlockVol{
		fd:       fd,
		path:     path,
		super:    sb,
		wal:      wal,
		dirtyMap: dirtyMap,
	}
	v.nextLSN.Store(nextLSN)
	v.healthy.Store(true)
	v.groupCommit = NewGroupCommitter(GroupCommitterConfig{
		SyncFunc:   fd.Sync,
		OnDegraded: func() { v.healthy.Store(false) },
	})
	go v.groupCommit.Run()
	v.flusher = NewFlusher(FlusherConfig{
		FD:       fd,
		Super:    &v.super,
		WAL:      wal,
		DirtyMap: dirtyMap,
	})
	go v.flusher.Run()
	return v, nil
}
||||
|
|
||||
|
// WriteLBA writes data at the given logical block address.
|
||||
|
// Data length must be a multiple of BlockSize.
|
||||
|
func (v *BlockVol) WriteLBA(lba uint64, data []byte) error { |
||||
|
if err := ValidateWrite(lba, uint32(len(data)), v.super.VolumeSize, v.super.BlockSize); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
lsn := v.nextLSN.Add(1) - 1 |
||||
|
entry := &WALEntry{ |
||||
|
LSN: lsn, |
||||
|
Epoch: 0, // Phase 1: no fencing
|
||||
|
Type: EntryTypeWrite, |
||||
|
LBA: lba, |
||||
|
Length: uint32(len(data)), |
||||
|
Data: data, |
||||
|
} |
||||
|
|
||||
|
walOff, err := v.wal.Append(entry) |
||||
|
if err != nil { |
||||
|
return fmt.Errorf("blockvol: WriteLBA: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Update dirty map: one entry per block written.
|
||||
|
blocks := uint32(len(data)) / v.super.BlockSize |
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
blockOff := walOff // all blocks in this entry share the same WAL offset
|
||||
|
v.dirtyMap.Put(lba+uint64(i), blockOff, lsn, v.super.BlockSize) |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// ReadLBA reads data at the given logical block address.
|
||||
|
// length is in bytes and must be a multiple of BlockSize.
|
||||
|
func (v *BlockVol) ReadLBA(lba uint64, length uint32) ([]byte, error) { |
||||
|
if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
blocks := length / v.super.BlockSize |
||||
|
result := make([]byte, length) |
||||
|
|
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
blockLBA := lba + uint64(i) |
||||
|
blockData, err := v.readOneBlock(blockLBA) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("blockvol: ReadLBA block %d: %w", blockLBA, err) |
||||
|
} |
||||
|
copy(result[i*v.super.BlockSize:], blockData) |
||||
|
} |
||||
|
|
||||
|
return result, nil |
||||
|
} |
||||
|
|
||||
|
// readOneBlock reads a single block, checking dirty map first, then extent.
|
||||
|
func (v *BlockVol) readOneBlock(lba uint64) ([]byte, error) { |
||||
|
walOff, _, _, ok := v.dirtyMap.Get(lba) |
||||
|
if ok { |
||||
|
return v.readBlockFromWAL(walOff, lba) |
||||
|
} |
||||
|
return v.readBlockFromExtent(lba) |
||||
|
} |
||||
|
|
||||
|
// maxWALEntryDataLen caps the data length we trust from a WAL entry header.
// Anything larger than the WAL region itself is corrupt.
// This absolute ceiling guards against allocating attacker/corruption-sized
// buffers even when the configured WAL region is huge.
const maxWALEntryDataLen = 256 * 1024 * 1024 // 256MB absolute ceiling
|
||||
|
|
||||
|
// readBlockFromWAL reads a block's data from its WAL entry.
|
||||
|
func (v *BlockVol) readBlockFromWAL(walOff uint64, lba uint64) ([]byte, error) { |
||||
|
// Read the WAL entry header to get the full entry size.
|
||||
|
headerBuf := make([]byte, walEntryHeaderSize) |
||||
|
absOff := int64(v.super.WALOffset + walOff) |
||||
|
if _, err := v.fd.ReadAt(headerBuf, absOff); err != nil { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: read header at %d: %w", absOff, err) |
||||
|
} |
||||
|
|
||||
|
// Check entry type first — TRIM has no data payload, so Length is
|
||||
|
// metadata (trim extent), not a data size to allocate.
|
||||
|
entryType := headerBuf[16] // Type is at offset LSN(8) + Epoch(8) = 16
|
||||
|
if entryType == EntryTypeTrim { |
||||
|
// TRIM entry: return zeros regardless of Length field.
|
||||
|
return make([]byte, v.super.BlockSize), nil |
||||
|
} |
||||
|
if entryType != EntryTypeWrite { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: expected WRITE or TRIM entry, got type 0x%02x", entryType) |
||||
|
} |
||||
|
|
||||
|
// Parse and validate the data Length field before allocating (WRITE only).
|
||||
|
dataLen := v.parseDataLength(headerBuf) |
||||
|
if uint64(dataLen) > v.super.WALSize || uint64(dataLen) > maxWALEntryDataLen { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: corrupt entry length %d exceeds WAL size %d", dataLen, v.super.WALSize) |
||||
|
} |
||||
|
|
||||
|
entryLen := walEntryHeaderSize + int(dataLen) |
||||
|
fullBuf := make([]byte, entryLen) |
||||
|
if _, err := v.fd.ReadAt(fullBuf, absOff); err != nil { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: read entry at %d: %w", absOff, err) |
||||
|
} |
||||
|
|
||||
|
entry, err := DecodeWALEntry(fullBuf) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: decode: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Find the block within the entry's data.
|
||||
|
blockOffset := (lba - entry.LBA) * uint64(v.super.BlockSize) |
||||
|
if blockOffset+uint64(v.super.BlockSize) > uint64(len(entry.Data)) { |
||||
|
return nil, fmt.Errorf("readBlockFromWAL: block offset %d out of range for entry data len %d", blockOffset, len(entry.Data)) |
||||
|
} |
||||
|
|
||||
|
block := make([]byte, v.super.BlockSize) |
||||
|
copy(block, entry.Data[blockOffset:blockOffset+uint64(v.super.BlockSize)]) |
||||
|
return block, nil |
||||
|
} |
||||
|
|
||||
|
// parseDataLength extracts the Length field from a WAL entry header buffer.
|
||||
|
func (v *BlockVol) parseDataLength(headerBuf []byte) uint32 { |
||||
|
// Length is at offset: LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8) = 26
|
||||
|
return binary.LittleEndian.Uint32(headerBuf[26:]) |
||||
|
} |
||||
|
|
||||
|
// readBlockFromExtent reads a block directly from the extent region.
|
||||
|
func (v *BlockVol) readBlockFromExtent(lba uint64) ([]byte, error) { |
||||
|
extentStart := v.super.WALOffset + v.super.WALSize |
||||
|
byteOffset := extentStart + lba*uint64(v.super.BlockSize) |
||||
|
|
||||
|
block := make([]byte, v.super.BlockSize) |
||||
|
if _, err := v.fd.ReadAt(block, int64(byteOffset)); err != nil { |
||||
|
return nil, fmt.Errorf("readBlockFromExtent: pread at %d: %w", byteOffset, err) |
||||
|
} |
||||
|
return block, nil |
||||
|
} |
||||
|
|
||||
|
// Trim marks blocks as deallocated. Subsequent reads return zeros.
|
||||
|
// The trim is recorded in the WAL with a Length field so the flusher
|
||||
|
// can zero the extent region and recovery can replay the trim.
|
||||
|
func (v *BlockVol) Trim(lba uint64, length uint32) error { |
||||
|
if err := ValidateWrite(lba, length, v.super.VolumeSize, v.super.BlockSize); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
lsn := v.nextLSN.Add(1) - 1 |
||||
|
entry := &WALEntry{ |
||||
|
LSN: lsn, |
||||
|
Epoch: 0, |
||||
|
Type: EntryTypeTrim, |
||||
|
LBA: lba, |
||||
|
Length: length, |
||||
|
} |
||||
|
|
||||
|
walOff, err := v.wal.Append(entry) |
||||
|
if err != nil { |
||||
|
return fmt.Errorf("blockvol: Trim: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Update dirty map: mark each trimmed block so the flusher sees it.
|
||||
|
// readOneBlock checks entry type and returns zeros for TRIM entries.
|
||||
|
blocks := length / v.super.BlockSize |
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
v.dirtyMap.Put(lba+uint64(i), walOff, lsn, v.super.BlockSize) |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Path returns the file path of the block volume.
|
||||
|
func (v *BlockVol) Path() string { |
||||
|
return v.path |
||||
|
} |
||||
|
|
||||
|
// Info returns volume metadata.
|
||||
|
func (v *BlockVol) Info() VolumeInfo { |
||||
|
return VolumeInfo{ |
||||
|
VolumeSize: v.super.VolumeSize, |
||||
|
BlockSize: v.super.BlockSize, |
||||
|
ExtentSize: v.super.ExtentSize, |
||||
|
WALSize: v.super.WALSize, |
||||
|
Healthy: v.healthy.Load(), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// VolumeInfo contains read-only volume metadata.
// It is a value snapshot taken by Info; mutating it has no effect on the
// underlying volume.
type VolumeInfo struct {
	VolumeSize uint64 // logical volume size in bytes
	BlockSize  uint32 // block size in bytes
	ExtentSize uint32 // extent size in bytes
	WALSize    uint64 // WAL region size in bytes
	Healthy    bool   // false after the group committer degrades
}
||||
|
|
||||
|
// SyncCache ensures all previously written WAL entries are durable on disk.
|
||||
|
// It submits a sync request to the group committer, which batches fsyncs.
|
||||
|
func (v *BlockVol) SyncCache() error { |
||||
|
return v.groupCommit.Submit() |
||||
|
} |
||||
|
|
||||
|
// Close shuts down the block volume and closes the file.
// Shutdown order: group committer → stop flusher goroutine → final flush → close fd.
//
// The flusher goroutine is stopped before the final FlushOnce so the two
// never run concurrently. A flush error takes precedence over a close error
// in the return value. Calling Close twice does not panic, but the second
// fd.Close returns an error.
func (v *BlockVol) Close() error {
	if v.groupCommit != nil {
		v.groupCommit.Stop()
	}
	var flushErr error
	if v.flusher != nil {
		v.flusher.Stop()                 // stop background goroutine first (no concurrent flush)
		flushErr = v.flusher.FlushOnce() // then do final flush safely
	}
	closeErr := v.fd.Close()
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}
||||
weed/storage/blockvol/blockvol_qa_test.go (3903 lines): file diff suppressed because it is too large. View file.
@ -0,0 +1,685 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"os" |
||||
|
"path/filepath" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestBlockVol(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "write_then_read", run: testWriteThenRead}, |
||||
|
{name: "overwrite_read_latest", run: testOverwriteReadLatest}, |
||||
|
{name: "write_no_sync_not_durable", run: testWriteNoSyncNotDurable}, |
||||
|
{name: "write_sync_durable", run: testWriteSyncDurable}, |
||||
|
{name: "write_multiple_sync", run: testWriteMultipleSync}, |
||||
|
{name: "read_unflushed", run: testReadUnflushed}, |
||||
|
{name: "read_flushed", run: testReadFlushed}, |
||||
|
{name: "read_mixed_dirty_clean", run: testReadMixedDirtyClean}, |
||||
|
{name: "wal_read_corrupt_length", run: testWALReadCorruptLength}, |
||||
|
{name: "open_invalid_superblock", run: testOpenInvalidSuperblock}, |
||||
|
{name: "trim_large_length_read_returns_zero", run: testTrimLargeLengthReadReturnsZero}, |
||||
|
// Task 1.10: Lifecycle tests.
|
||||
|
{name: "lifecycle_create_close_reopen", run: testLifecycleCreateCloseReopen}, |
||||
|
{name: "lifecycle_close_flushes_dirty", run: testLifecycleCloseFlushes}, |
||||
|
{name: "lifecycle_double_close", run: testLifecycleDoubleClose}, |
||||
|
{name: "lifecycle_info", run: testLifecycleInfo}, |
||||
|
{name: "lifecycle_write_sync_close_reopen", run: testLifecycleWriteSyncCloseReopen}, |
||||
|
// Task 1.11: Crash stress test.
|
||||
|
{name: "crash_stress_100", run: testCrashStress100}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func createTestVol(t *testing.T) *BlockVol { |
||||
|
t.Helper() |
||||
|
dir := t.TempDir() |
||||
|
path := filepath.Join(dir, "test.blockvol") |
||||
|
v, err := CreateBlockVol(path, CreateOptions{ |
||||
|
VolumeSize: 1 * 1024 * 1024, // 1MB
|
||||
|
BlockSize: 4096, |
||||
|
WALSize: 256 * 1024, // 256KB WAL
|
||||
|
}) |
||||
|
if err != nil { |
||||
|
t.Fatalf("CreateBlockVol: %v", err) |
||||
|
} |
||||
|
return v |
||||
|
} |
||||
|
|
||||
|
func makeBlock(fill byte) []byte { |
||||
|
b := make([]byte, 4096) |
||||
|
for i := range b { |
||||
|
b[i] = fill |
||||
|
} |
||||
|
return b |
||||
|
} |
||||
|
|
||||
|
// testWriteThenRead verifies a single-block write is readable back verbatim.
func testWriteThenRead(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("read data does not match written data")
	}
}

// testOverwriteReadLatest verifies an overwrite wins: a read after two
// writes to the same LBA returns the second payload.
func testOverwriteReadLatest(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(A): %v", err)
	}
	if err := v.WriteLBA(0, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(B): %v", err)
	}

	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, makeBlock('B')) {
		t.Error("read should return latest write ('B'), not 'A'")
	}
}
||||
|
|
||||
|
// testWriteNoSyncNotDurable simulates a crash (fd closed with no sync) and
// checks only that the volume can be reopened — unsynced data is allowed to
// be lost or present.
func testWriteNoSyncNotDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()

	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Simulate crash: close fd without sync.
	v.fd.Close()

	// Reopen -- without recovery (Phase 1.9), data MAY be lost.
	// This test just verifies we can reopen without error.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol after crash: %v", err)
	}
	defer v2.Close()
	// Data may or may not be present -- both are correct without SyncCache.
}
||||
|
|
||||
|
// testWriteSyncDurable writes one block, manually syncs the WAL and persists
// the superblock head, "crashes" by closing the fd, then reopens and
// hand-replays the WAL entry into the dirty map to prove the data survived.
func testWriteSyncDurable(t *testing.T) {
	v := createTestVol(t)
	path := v.Path()

	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Sync WAL (manual fsync for now, group commit in Task 1.7).
	if err := v.wal.Sync(); err != nil {
		t.Fatalf("Sync: %v", err)
	}

	// Update superblock WALHead so reopen knows where entries are.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALCheckpointLSN = 0
	if _, err := v.fd.Seek(0, 0); err != nil {
		t.Fatalf("Seek: %v", err)
	}
	if _, err := v.super.WriteTo(v.fd); err != nil {
		t.Fatalf("WriteTo: %v", err)
	}
	v.fd.Sync()
	v.fd.Close()

	// Reopen and manually replay WAL to verify data is durable.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Manually replay: read WAL entry and populate dirty map.
	replayBuf := make([]byte, v2.super.WALHead)
	if _, err := v2.fd.ReadAt(replayBuf, int64(v2.super.WALOffset)); err != nil {
		t.Fatalf("read WAL for replay: %v", err)
	}
	entry, err := DecodeWALEntry(replayBuf)
	if err != nil {
		t.Fatalf("decode WAL entry: %v", err)
	}
	// Mark every block of the replayed entry dirty at WAL offset 0
	// (the entry was the first thing appended).
	blocks := entry.Length / v2.super.BlockSize
	for i := uint32(0); i < blocks; i++ {
		v2.dirtyMap.Put(entry.LBA+uint64(i), 0, entry.LSN, v2.super.BlockSize)
	}

	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not durable after sync + reopen")
	}
}
||||
|
|
||||
|
func testWriteMultipleSync(t *testing.T) { |
||||
|
v := createTestVol(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
// Write 10 blocks with different data.
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Read all back.
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
got, err := v.ReadLBA(i, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA(%d): %v", i, err) |
||||
|
} |
||||
|
expected := makeBlock(byte('A' + i)) |
||||
|
if !bytes.Equal(got, expected) { |
||||
|
t.Errorf("block %d: data mismatch", i) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testReadUnflushed verifies a freshly written block is readable before the
// flusher has moved it to the extent region (i.e. served via dirty map/WAL).
func testReadUnflushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('X')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Read before flusher runs -- should come from dirty map / WAL.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("unflushed read: data mismatch")
	}
}

// testReadFlushed simulates a completed flush (data copied to the extent
// region, dirty-map entry removed) and verifies reads come from the extent.
func testReadFlushed(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	data := makeBlock('F')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Manually flush: copy WAL data to extent region, clear dirty map.
	extentStart := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(data, int64(extentStart)); err != nil {
		t.Fatalf("manual flush write: %v", err)
	}
	v.dirtyMap.Delete(0)

	// Read should now come from extent region.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after flush: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("flushed read: data mismatch")
	}
}
||||
|
|
||||
|
// testReadMixedDirtyClean covers the three read paths in one volume:
// a flushed block (extent), dirty blocks (WAL), and never-written blocks
// (extent zeros).
func testReadMixedDirtyClean(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write blocks 0, 2, 4 (dirty).
	for _, lba := range []uint64{0, 2, 4} {
		if err := v.WriteLBA(lba, makeBlock(byte('A'+lba))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", lba, err)
		}
	}

	// Manually flush block 0 to extent, remove from dirty map.
	extentStart := v.super.WALOffset + v.super.WALSize
	if _, err := v.fd.WriteAt(makeBlock('A'), int64(extentStart)); err != nil {
		t.Fatalf("manual flush: %v", err)
	}
	v.dirtyMap.Delete(0)

	// Block 0: from extent (flushed)
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0): %v", err)
	}
	if !bytes.Equal(got, makeBlock('A')) {
		t.Error("block 0 (flushed) mismatch")
	}

	// Block 2: from dirty map (WAL)
	got, err = v.ReadLBA(2, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(2): %v", err)
	}
	if !bytes.Equal(got, makeBlock('C')) { // 'A'+2 = 'C'
		t.Error("block 2 (dirty) mismatch")
	}

	// Block 4: from dirty map (WAL)
	got, err = v.ReadLBA(4, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(4): %v", err)
	}
	if !bytes.Equal(got, makeBlock('E')) { // 'A'+4 = 'E'
		t.Error("block 4 (dirty) mismatch")
	}

	// Blocks 1, 3, 5: never written, should be zeros (from extent region).
	for _, lba := range []uint64{1, 3, 5} {
		got, err = v.ReadLBA(lba, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", lba, err)
		}
		if !bytes.Equal(got, make([]byte, 4096)) {
			t.Errorf("block %d (unwritten) should be zeros", lba)
		}
	}
}
||||
|
|
||||
|
// testWALReadCorruptLength corrupts a WAL entry's on-disk Length field with
// a huge value and verifies ReadLBA returns an error rather than panicking
// or attempting a ~1GB allocation.
func testWALReadCorruptLength(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write a valid block.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Get the WAL offset from dirty map.
	walOff, _, _, ok := v.dirtyMap.Get(0)
	if !ok {
		t.Fatal("block 0 not in dirty map")
	}

	// Corrupt the Length field in the WAL entry on disk.
	// Length is at header offset 26 (LSN=8 + Epoch=8 + Type=1 + Flags=1 + LBA=8).
	absOff := int64(v.super.WALOffset + walOff)
	lengthOff := absOff + 26
	var hugeLenBuf [4]byte
	binary.LittleEndian.PutUint32(hugeLenBuf[:], 999999999) // ~1GB
	if _, err := v.fd.WriteAt(hugeLenBuf[:], lengthOff); err != nil {
		t.Fatalf("corrupt length: %v", err)
	}

	// ReadLBA should detect the corrupt length and error (not panic/OOM).
	_, err := v.ReadLBA(0, 4096)
	if err == nil {
		t.Error("expected error reading corrupt WAL entry, got nil")
	}
}
||||
|
|
||||
|
// testOpenInvalidSuperblock zeroes the on-disk BlockSize field of a valid
// volume and verifies OpenBlockVol rejects it with ErrInvalidSuperblock.
func testOpenInvalidSuperblock(t *testing.T) {
	dir := t.TempDir()

	// Create a valid volume, then corrupt BlockSize to 0.
	path := filepath.Join(dir, "corrupt.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{VolumeSize: 1024 * 1024})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}
	v.Close()

	// Corrupt BlockSize (at superblock offset 36: Magic=4 + Version=2 + Flags=2 + UUID=16 + VolumeSize=8 + ExtentSize=4 = 36).
	fd, err := os.OpenFile(path, os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("open for corrupt: %v", err)
	}
	var zeroBuf [4]byte
	if _, err := fd.WriteAt(zeroBuf[:], 36); err != nil {
		fd.Close()
		t.Fatalf("corrupt blocksize: %v", err)
	}
	fd.Close()

	// OpenBlockVol should reject the corrupt superblock.
	_, err = OpenBlockVol(path)
	if err == nil {
		t.Fatal("expected error opening volume with BlockSize=0")
	}
	if !errors.Is(err, ErrInvalidSuperblock) {
		t.Errorf("expected ErrInvalidSuperblock, got: %v", err)
	}
}
||||
|
|
||||
|
// --- Task 1.10: Lifecycle tests ---
|
||||
|
|
||||
|
// testLifecycleCreateCloseReopen exercises the full happy path:
// create → write → SyncCache → Close → reopen → read, and verifies the data
// survived the clean shutdown.
func testLifecycleCreateCloseReopen(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "lifecycle.blockvol")

	// Create, write, close.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	data := makeBlock('L')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	// Reopen and verify data survived.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not durable after create→write→sync→close→reopen")
	}
}
||||
|
|
||||
|
// testLifecycleCloseFlushes writes several blocks and relies on Close's
// final flush to drain the dirty map; reopening must see all the data.
func testLifecycleCloseFlushes(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "flush.blockvol")

	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    256 * 1024,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	// Write several blocks.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Close does a final flush — dirty map should be drained.
	if err := v.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	// Reopen and verify.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	for i := uint64(0); i < 5; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after close+reopen", i)
		}
	}
}
||||
|
|
||||
|
func testLifecycleDoubleClose(t *testing.T) { |
||||
|
v := createTestVol(t) |
||||
|
if err := v.Close(); err != nil { |
||||
|
t.Fatalf("first Close: %v", err) |
||||
|
} |
||||
|
// Second close should not panic (group committer + flusher are idempotent).
|
||||
|
// The fd.Close() will return an error but should not panic.
|
||||
|
_ = v.Close() |
||||
|
} |
||||
|
|
||||
|
func testLifecycleInfo(t *testing.T) { |
||||
|
v := createTestVol(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
info := v.Info() |
||||
|
if info.VolumeSize != 1*1024*1024 { |
||||
|
t.Errorf("VolumeSize = %d, want %d", info.VolumeSize, 1*1024*1024) |
||||
|
} |
||||
|
if info.BlockSize != 4096 { |
||||
|
t.Errorf("BlockSize = %d, want 4096", info.BlockSize) |
||||
|
} |
||||
|
if !info.Healthy { |
||||
|
t.Error("Healthy = false, want true") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testLifecycleWriteSyncCloseReopen(t *testing.T) { |
||||
|
dir := t.TempDir() |
||||
|
path := filepath.Join(dir, "wsco.blockvol") |
||||
|
|
||||
|
// Cycle: create → write → sync → close → reopen → write → sync → close → verify.
|
||||
|
v, err := CreateBlockVol(path, CreateOptions{ |
||||
|
VolumeSize: 1 * 1024 * 1024, |
||||
|
BlockSize: 4096, |
||||
|
WALSize: 256 * 1024, |
||||
|
}) |
||||
|
if err != nil { |
||||
|
t.Fatalf("CreateBlockVol: %v", err) |
||||
|
} |
||||
|
|
||||
|
if err := v.WriteLBA(0, makeBlock('X')); err != nil { |
||||
|
t.Fatalf("WriteLBA round 1: %v", err) |
||||
|
} |
||||
|
if err := v.SyncCache(); err != nil { |
||||
|
t.Fatalf("SyncCache round 1: %v", err) |
||||
|
} |
||||
|
v.Close() |
||||
|
|
||||
|
// Reopen, write more.
|
||||
|
v2, err := OpenBlockVol(path) |
||||
|
if err != nil { |
||||
|
t.Fatalf("OpenBlockVol round 2: %v", err) |
||||
|
} |
||||
|
|
||||
|
if err := v2.WriteLBA(1, makeBlock('Y')); err != nil { |
||||
|
t.Fatalf("WriteLBA round 2: %v", err) |
||||
|
} |
||||
|
// Overwrite block 0.
|
||||
|
if err := v2.WriteLBA(0, makeBlock('Z')); err != nil { |
||||
|
t.Fatalf("WriteLBA overwrite: %v", err) |
||||
|
} |
||||
|
if err := v2.SyncCache(); err != nil { |
||||
|
t.Fatalf("SyncCache round 2: %v", err) |
||||
|
} |
||||
|
v2.Close() |
||||
|
|
||||
|
// Final reopen — verify.
|
||||
|
v3, err := OpenBlockVol(path) |
||||
|
if err != nil { |
||||
|
t.Fatalf("OpenBlockVol round 3: %v", err) |
||||
|
} |
||||
|
defer v3.Close() |
||||
|
|
||||
|
got0, _ := v3.ReadLBA(0, 4096) |
||||
|
if !bytes.Equal(got0, makeBlock('Z')) { |
||||
|
t.Error("block 0: expected 'Z' (overwritten)") |
||||
|
} |
||||
|
got1, _ := v3.ReadLBA(1, 4096) |
||||
|
if !bytes.Equal(got1, makeBlock('Y')) { |
||||
|
t.Error("block 1: expected 'Y'") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// --- Task 1.11: Crash stress test ---
|
||||
|
|
||||
|
// testCrashStress100 is a crash-recovery stress test: 100 iterations of
// (deterministic pseudo-random writes/trims → SyncCache → simulated hard
// crash → reopen with recovery → verify every block against an in-memory
// oracle). The "crash" is simulated by syncing the WAL, persisting the
// superblock with current WAL positions, stopping the background goroutines,
// and closing the fd without the normal Close path.
func testCrashStress100(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "stress.blockvol")

	const (
		volumeSize = 256 * 1024 // 256KB volume (64 blocks of 4KB)
		blockSize  = 4096
		walSize    = 64 * 1024 // 64KB WAL
		maxLBA     = volumeSize / blockSize
		iterations = 100
	)

	// Oracle: tracks the expected state of each block.
	// Fill byte 0 doubles as the "zeros/trimmed" marker, which is safe because
	// all written fill bytes below are in 'A'..'Z'.
	oracle := make(map[uint64]byte) // lba → fill byte (0 = zeros/trimmed)

	// Create initial volume.
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: volumeSize,
		BlockSize:  blockSize,
		WALSize:    walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	for iter := 0; iter < iterations; iter++ {
		// Deterministic "random" ops using iteration number, so failures
		// reproduce exactly.
		numOps := 3 + (iter % 5) // 3-7 ops per iteration

		for op := 0; op < numOps; op++ {
			lba := uint64((iter*7 + op*13) % maxLBA)
			action := (iter + op) % 3 // 0=write, 1=overwrite, 2=trim

			switch action {
			case 0, 1: // write
				fill := byte('A' + (iter+op)%26)
				err := v.WriteLBA(lba, makeBlock(fill))
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						// WAL-full is an expected transient; the op is simply
						// skipped and the oracle left untouched.
						continue // WAL full, skip this op
					}
					t.Fatalf("iter %d op %d: WriteLBA(%d): %v", iter, op, lba, err)
				}
				oracle[lba] = fill
			case 2: // trim
				err := v.Trim(lba, blockSize)
				if err != nil {
					if errors.Is(err, ErrWALFull) {
						continue
					}
					t.Fatalf("iter %d op %d: Trim(%d): %v", iter, op, lba, err)
				}
				oracle[lba] = 0
			}
		}

		// Sync WAL so the just-written entries are candidates for recovery.
		if err := v.SyncCache(); err != nil {
			t.Fatalf("iter %d: SyncCache: %v", iter, err)
		}

		// Simulate crash: close fd without clean shutdown.
		// This deliberately reaches into BlockVol internals (fd, super, wal);
		// the ordering of the next few statements is the scenario under test.
		v.fd.Sync() // ensure WAL is on disk
		// Write superblock with current WAL positions for recovery.
		v.super.WALHead = v.wal.LogicalHead()
		v.super.WALTail = v.wal.LogicalTail()
		if _, seekErr := v.fd.Seek(0, 0); seekErr != nil {
			t.Fatalf("iter %d: Seek: %v", iter, seekErr)
		}
		v.super.WriteTo(v.fd)
		v.fd.Sync()

		// Hard crash: close fd, stop goroutines (no final flush).
		v.groupCommit.Stop()
		v.flusher.Stop()
		v.fd.Close()

		// Reopen with recovery.
		v, err = OpenBlockVol(path)
		if err != nil {
			t.Fatalf("iter %d: OpenBlockVol: %v", iter, err)
		}

		// Verify oracle against actual reads.
		for lba, fill := range oracle {
			got, readErr := v.ReadLBA(lba, blockSize)
			if readErr != nil {
				t.Fatalf("iter %d: ReadLBA(%d): %v", iter, lba, readErr)
			}
			var expected []byte
			if fill == 0 {
				expected = make([]byte, blockSize)
			} else {
				expected = makeBlock(fill)
			}
			if !bytes.Equal(got, expected) {
				t.Fatalf("iter %d: block %d mismatch: got[0]=%d want[0]=%d", iter, lba, got[0], expected[0])
			}
		}
	}

	v.Close()
}
||||
|
|
||||
|
// testTrimLargeLengthReadReturnsZero writes a block, trims it, and verifies
// the subsequent read returns zeros.
//
// NOTE(review): despite the test name and the original comment about "a
// length larger than WAL size", the code trims exactly one 4096-byte block.
// Confirm whether Trim(0, <something > WALSize>) was intended — the comment
// below suggests trim Length is metadata (a trim extent), not a data
// allocation, so a larger length should also be exercisable.
func testTrimLargeLengthReadReturnsZero(t *testing.T) {
	v := createTestVol(t)
	defer v.Close()

	// Write data first.
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Verify data is written before trimming.
	got, err := v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA before trim: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Fatal("data mismatch before trim")
	}

	// The trim Length is metadata (trim extent), not a data allocation.
	if err := v.Trim(0, 4096); err != nil {
		t.Fatalf("Trim: %v", err)
	}

	// Read should return zeros (TRIM entry in dirty map).
	got, err = v.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after trim: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("read after trim should return zeros")
	}
}
||||
@ -0,0 +1,110 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import "sync" |
||||
|
|
||||
|
// dirtyEntry tracks a single LBA's location in the WAL: where the data
// currently lives, which write produced it, and how long it is.
type dirtyEntry struct {
	walOffset uint64 // WAL-relative offset of the entry holding this LBA's data
	lsn       uint64 // LSN of the write that produced this entry
	length    uint32 // payload length recorded by Put for this entry
}
||||
|
|
||||
|
// DirtyMap is a concurrency-safe map from LBA to WAL offset. It answers
// "is this LBA's latest data still in the WAL, and if so where?".
// Phase 1 uses a single shard; Phase 3 upgrades to 256 shards.
type DirtyMap struct {
	mu sync.RWMutex // guards m; RLock for Get/Range/Len/Snapshot, Lock for Put/Delete
	m  map[uint64]dirtyEntry
}
||||
|
|
||||
|
// NewDirtyMap creates an empty dirty map.
|
||||
|
func NewDirtyMap() *DirtyMap { |
||||
|
return &DirtyMap{m: make(map[uint64]dirtyEntry)} |
||||
|
} |
||||
|
|
||||
|
// Put records that the given LBA has dirty data at walOffset.
|
||||
|
func (d *DirtyMap) Put(lba uint64, walOffset uint64, lsn uint64, length uint32) { |
||||
|
d.mu.Lock() |
||||
|
d.m[lba] = dirtyEntry{walOffset: walOffset, lsn: lsn, length: length} |
||||
|
d.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// Get returns the dirty entry for lba. ok is false if lba is not dirty.
|
||||
|
func (d *DirtyMap) Get(lba uint64) (walOffset uint64, lsn uint64, length uint32, ok bool) { |
||||
|
d.mu.RLock() |
||||
|
e, found := d.m[lba] |
||||
|
d.mu.RUnlock() |
||||
|
if !found { |
||||
|
return 0, 0, 0, false |
||||
|
} |
||||
|
return e.walOffset, e.lsn, e.length, true |
||||
|
} |
||||
|
|
||||
|
// Delete removes a single LBA from the dirty map.
|
||||
|
func (d *DirtyMap) Delete(lba uint64) { |
||||
|
d.mu.Lock() |
||||
|
delete(d.m, lba) |
||||
|
d.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// rangeEntry is a snapshot of one dirty entry (plus its key) taken under
// the read lock, so Range can invoke its callback without holding the lock.
type rangeEntry struct {
	lba       uint64 // the map key the entry was snapshotted from
	walOffset uint64
	lsn       uint64
	length    uint32
}
||||
|
|
||||
|
// Range calls fn for each dirty entry with LBA in [start, start+count).
|
||||
|
// Entries are copied under lock, then fn is called without holding the lock,
|
||||
|
// so fn may safely call back into DirtyMap (Put, Delete, etc.).
|
||||
|
func (d *DirtyMap) Range(start uint64, count uint32, fn func(lba, walOffset, lsn uint64, length uint32)) { |
||||
|
end := start + uint64(count) |
||||
|
|
||||
|
// Snapshot matching entries under lock.
|
||||
|
d.mu.RLock() |
||||
|
entries := make([]rangeEntry, 0, len(d.m)) |
||||
|
for lba, e := range d.m { |
||||
|
if lba >= start && lba < end { |
||||
|
entries = append(entries, rangeEntry{lba, e.walOffset, e.lsn, e.length}) |
||||
|
} |
||||
|
} |
||||
|
d.mu.RUnlock() |
||||
|
|
||||
|
// Call fn without lock.
|
||||
|
for _, e := range entries { |
||||
|
fn(e.lba, e.walOffset, e.lsn, e.length) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Len returns the number of dirty entries.
|
||||
|
func (d *DirtyMap) Len() int { |
||||
|
d.mu.RLock() |
||||
|
n := len(d.m) |
||||
|
d.mu.RUnlock() |
||||
|
return n |
||||
|
} |
||||
|
|
||||
|
// SnapshotEntry is an exported snapshot of one dirty entry, as returned by
// Snapshot. It mirrors dirtyEntry plus the LBA key.
type SnapshotEntry struct {
	Lba       uint64 // the dirty LBA
	WalOffset uint64 // WAL-relative offset of the entry's data
	Lsn       uint64 // LSN of the write that produced the entry
	Length    uint32 // payload length recorded for the entry
}
||||
|
|
||||
|
// Snapshot returns a copy of all dirty entries. The snapshot is taken under
|
||||
|
// read lock but returned without holding the lock.
|
||||
|
func (d *DirtyMap) Snapshot() []SnapshotEntry { |
||||
|
d.mu.RLock() |
||||
|
entries := make([]SnapshotEntry, 0, len(d.m)) |
||||
|
for lba, e := range d.m { |
||||
|
entries = append(entries, SnapshotEntry{ |
||||
|
Lba: lba, |
||||
|
WalOffset: e.walOffset, |
||||
|
Lsn: e.lsn, |
||||
|
Length: e.length, |
||||
|
}) |
||||
|
} |
||||
|
d.mu.RUnlock() |
||||
|
return entries |
||||
|
} |
||||
@ -0,0 +1,153 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"sync" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestDirtyMap(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "dirty_put_get", run: testDirtyPutGet}, |
||||
|
{name: "dirty_overwrite", run: testDirtyOverwrite}, |
||||
|
{name: "dirty_delete", run: testDirtyDelete}, |
||||
|
{name: "dirty_range_query", run: testDirtyRangeQuery}, |
||||
|
{name: "dirty_empty", run: testDirtyEmpty}, |
||||
|
{name: "dirty_concurrent_rw", run: testDirtyConcurrentRW}, |
||||
|
{name: "dirty_range_modify", run: testDirtyRangeModify}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyPutGet(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
dm.Put(100, 5000, 1, 4096) |
||||
|
|
||||
|
off, lsn, length, ok := dm.Get(100) |
||||
|
if !ok { |
||||
|
t.Fatal("Get(100) returned not-found") |
||||
|
} |
||||
|
if off != 5000 { |
||||
|
t.Errorf("walOffset = %d, want 5000", off) |
||||
|
} |
||||
|
if lsn != 1 { |
||||
|
t.Errorf("lsn = %d, want 1", lsn) |
||||
|
} |
||||
|
if length != 4096 { |
||||
|
t.Errorf("length = %d, want 4096", length) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyOverwrite(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
dm.Put(100, 5000, 1, 4096) |
||||
|
dm.Put(100, 9000, 2, 4096) |
||||
|
|
||||
|
off, lsn, _, ok := dm.Get(100) |
||||
|
if !ok { |
||||
|
t.Fatal("Get(100) returned not-found") |
||||
|
} |
||||
|
if off != 9000 { |
||||
|
t.Errorf("walOffset = %d, want 9000 (latest)", off) |
||||
|
} |
||||
|
if lsn != 2 { |
||||
|
t.Errorf("lsn = %d, want 2 (latest)", lsn) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyDelete(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
dm.Put(100, 5000, 1, 4096) |
||||
|
dm.Delete(100) |
||||
|
|
||||
|
_, _, _, ok := dm.Get(100) |
||||
|
if ok { |
||||
|
t.Error("Get(100) should return not-found after delete") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyRangeQuery(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
for i := uint64(100); i < 110; i++ { |
||||
|
dm.Put(i, i*1000, i, 4096) |
||||
|
} |
||||
|
|
||||
|
found := make(map[uint64]bool) |
||||
|
dm.Range(100, 10, func(lba, walOffset, lsn uint64, length uint32) { |
||||
|
found[lba] = true |
||||
|
}) |
||||
|
|
||||
|
if len(found) != 10 { |
||||
|
t.Errorf("Range returned %d entries, want 10", len(found)) |
||||
|
} |
||||
|
for i := uint64(100); i < 110; i++ { |
||||
|
if !found[i] { |
||||
|
t.Errorf("Range missing LBA %d", i) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyEmpty(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
|
||||
|
_, _, _, ok := dm.Get(0) |
||||
|
if ok { |
||||
|
t.Error("empty map: Get(0) should return not-found") |
||||
|
} |
||||
|
_, _, _, ok = dm.Get(999) |
||||
|
if ok { |
||||
|
t.Error("empty map: Get(999) should return not-found") |
||||
|
} |
||||
|
if dm.Len() != 0 { |
||||
|
t.Errorf("empty map: Len() = %d, want 0", dm.Len()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDirtyConcurrentRW(t *testing.T) { |
||||
|
dm := NewDirtyMap() |
||||
|
const goroutines = 16 |
||||
|
const opsPerGoroutine = 1000 |
||||
|
|
||||
|
var wg sync.WaitGroup |
||||
|
wg.Add(goroutines) |
||||
|
for g := 0; g < goroutines; g++ { |
||||
|
go func(id int) { |
||||
|
defer wg.Done() |
||||
|
base := uint64(id * opsPerGoroutine) |
||||
|
for i := uint64(0); i < opsPerGoroutine; i++ { |
||||
|
lba := base + i |
||||
|
dm.Put(lba, lba*10, lba, 4096) |
||||
|
dm.Get(lba) |
||||
|
if i%3 == 0 { |
||||
|
dm.Delete(lba) |
||||
|
} |
||||
|
} |
||||
|
}(g) |
||||
|
} |
||||
|
wg.Wait() |
||||
|
|
||||
|
// No assertion on final state -- the test passes if no race detector fires.
|
||||
|
} |
||||
|
|
||||
|
func testDirtyRangeModify(t *testing.T) { |
||||
|
// Verify that Range callback can safely call Delete without deadlock.
|
||||
|
dm := NewDirtyMap() |
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
dm.Put(i, i*1000, i+1, 4096) |
||||
|
} |
||||
|
|
||||
|
// Delete every entry from within the Range callback.
|
||||
|
dm.Range(0, 10, func(lba, walOffset, lsn uint64, length uint32) { |
||||
|
dm.Delete(lba) |
||||
|
}) |
||||
|
|
||||
|
if dm.Len() != 0 { |
||||
|
t.Errorf("after Range+Delete: Len() = %d, want 0", dm.Len()) |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,239 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"fmt" |
||||
|
"os" |
||||
|
"sync" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// Flusher copies WAL entries to the extent region and frees WAL space.
// It runs as a background goroutine (Run) and can also be triggered
// manually: Notify for an asynchronous wake-up, FlushOnce for a
// synchronous cycle.
type Flusher struct {
	fd       *os.File    // volume file; shared with the rest of the engine
	super    *Superblock // in-memory superblock, rewritten on each checkpoint
	wal      *WALWriter  // source of WAL head/tail positions
	dirtyMap *DirtyMap   // entries to be flushed; drained as they are copied

	walOffset uint64 // absolute file offset of WAL region
	walSize uint64
	blockSize uint32
	extentStart uint64 // absolute file offset of extent region

	mu sync.Mutex // guards checkpointLSN and checkpointTail
	checkpointLSN uint64 // last flushed LSN
	checkpointTail uint64 // WAL physical tail after last flush

	interval time.Duration // period of the automatic flush ticker in Run
	notifyCh chan struct{} // capacity-1 wake-up channel; sends are coalesced
	stopCh chan struct{} // closed by Stop to terminate Run
	done chan struct{} // closed when Run returns; Stop waits on it
	stopOnce sync.Once // makes Stop idempotent
}
||||
|
|
||||
|
// FlusherConfig configures the flusher. FD, Super, WAL and DirtyMap must
// all refer to the same volume; a zero Interval selects the 100ms default
// in NewFlusher.
type FlusherConfig struct {
	FD       *os.File    // volume file descriptor
	Super    *Superblock // volume superblock (supplies WAL geometry)
	WAL      *WALWriter  // WAL writer whose tail the flusher advances
	DirtyMap *DirtyMap   // dirty map the flusher drains
	Interval time.Duration // default 100ms
}
||||
|
|
||||
|
// NewFlusher creates a flusher. Call Run() in a goroutine.
|
||||
|
func NewFlusher(cfg FlusherConfig) *Flusher { |
||||
|
if cfg.Interval == 0 { |
||||
|
cfg.Interval = 100 * time.Millisecond |
||||
|
} |
||||
|
return &Flusher{ |
||||
|
fd: cfg.FD, |
||||
|
super: cfg.Super, |
||||
|
wal: cfg.WAL, |
||||
|
dirtyMap: cfg.DirtyMap, |
||||
|
walOffset: cfg.Super.WALOffset, |
||||
|
walSize: cfg.Super.WALSize, |
||||
|
blockSize: cfg.Super.BlockSize, |
||||
|
extentStart: cfg.Super.WALOffset + cfg.Super.WALSize, |
||||
|
checkpointLSN: cfg.Super.WALCheckpointLSN, |
||||
|
checkpointTail: 0, |
||||
|
interval: cfg.Interval, |
||||
|
notifyCh: make(chan struct{}, 1), |
||||
|
stopCh: make(chan struct{}), |
||||
|
done: make(chan struct{}), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Run is the flusher main loop. Call in a goroutine.
|
||||
|
func (f *Flusher) Run() { |
||||
|
defer close(f.done) |
||||
|
ticker := time.NewTicker(f.interval) |
||||
|
defer ticker.Stop() |
||||
|
|
||||
|
for { |
||||
|
select { |
||||
|
case <-f.stopCh: |
||||
|
return |
||||
|
case <-ticker.C: |
||||
|
f.FlushOnce() |
||||
|
case <-f.notifyCh: |
||||
|
f.FlushOnce() |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Notify wakes the flusher for an immediate flush cycle. The send is
// non-blocking: notifyCh has capacity 1, so if a wake-up is already
// pending this call coalesces with it and returns without blocking.
func (f *Flusher) Notify() {
	select {
	case f.notifyCh <- struct{}{}:
	default:
		// A notification is already queued; dropping this one is fine.
	}
}
||||
|
|
||||
|
// Stop shuts down the flusher and waits for Run to exit. Safe to call
// multiple times: the channel close is guarded by stopOnce, and repeat
// callers simply re-wait on the already-closed done channel.
//
// NOTE(review): done is closed only when Run returns, so calling Stop on a
// flusher whose Run was never started would block forever — confirm all
// construction paths start Run before Stop can be reached.
func (f *Flusher) Stop() {
	f.stopOnce.Do(func() {
		close(f.stopCh)
	})
	<-f.done
}
||||
|
|
||||
|
// FlushOnce performs a single flush cycle: scan dirty map, copy data to
|
||||
|
// extent region, fsync, update checkpoint, advance WAL tail.
|
||||
|
func (f *Flusher) FlushOnce() error { |
||||
|
// Snapshot dirty entries. We use a full scan (Range over all possible LBAs
|
||||
|
// is impractical), so we collect from the dirty map directly.
|
||||
|
type flushEntry struct { |
||||
|
lba uint64 |
||||
|
walOff uint64 |
||||
|
lsn uint64 |
||||
|
length uint32 |
||||
|
} |
||||
|
|
||||
|
entries := f.dirtyMap.Snapshot() |
||||
|
if len(entries) == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Find the max LSN and max WAL offset to know where to advance tail.
|
||||
|
var maxLSN uint64 |
||||
|
var maxWALEnd uint64 |
||||
|
|
||||
|
for _, e := range entries { |
||||
|
// Read the WAL entry and copy data to extent region.
|
||||
|
headerBuf := make([]byte, walEntryHeaderSize) |
||||
|
absWALOff := int64(f.walOffset + e.WalOffset) |
||||
|
if _, err := f.fd.ReadAt(headerBuf, absWALOff); err != nil { |
||||
|
return fmt.Errorf("flusher: read WAL header at %d: %w", absWALOff, err) |
||||
|
} |
||||
|
|
||||
|
// Parse entry type and length.
|
||||
|
entryType := headerBuf[16] // Type at LSN(8)+Epoch(8)=16
|
||||
|
dataLen := parseLength(headerBuf) |
||||
|
|
||||
|
if entryType == EntryTypeWrite && dataLen > 0 { |
||||
|
// Read full entry.
|
||||
|
entryLen := walEntryHeaderSize + int(dataLen) |
||||
|
fullBuf := make([]byte, entryLen) |
||||
|
if _, err := f.fd.ReadAt(fullBuf, absWALOff); err != nil { |
||||
|
return fmt.Errorf("flusher: read WAL entry at %d: %w", absWALOff, err) |
||||
|
} |
||||
|
|
||||
|
entry, err := DecodeWALEntry(fullBuf) |
||||
|
if err != nil { |
||||
|
return fmt.Errorf("flusher: decode WAL entry: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Write data to extent region.
|
||||
|
blocks := entry.Length / f.blockSize |
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
blockLBA := entry.LBA + uint64(i) |
||||
|
extentOff := int64(f.extentStart + blockLBA*uint64(f.blockSize)) |
||||
|
blockData := entry.Data[i*f.blockSize : (i+1)*f.blockSize] |
||||
|
if _, err := f.fd.WriteAt(blockData, extentOff); err != nil { |
||||
|
return fmt.Errorf("flusher: write extent at LBA %d: %w", blockLBA, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Track WAL end position for tail advance.
|
||||
|
walEnd := e.WalOffset + uint64(entryLen) |
||||
|
if walEnd > maxWALEnd { |
||||
|
maxWALEnd = walEnd |
||||
|
} |
||||
|
} else if entryType == EntryTypeTrim { |
||||
|
// TRIM entries: zero the extent region for this LBA.
|
||||
|
// Each dirty map entry represents one trimmed block.
|
||||
|
zeroBlock := make([]byte, f.blockSize) |
||||
|
extentOff := int64(f.extentStart + e.Lba*uint64(f.blockSize)) |
||||
|
if _, err := f.fd.WriteAt(zeroBlock, extentOff); err != nil { |
||||
|
return fmt.Errorf("flusher: zero extent at LBA %d: %w", e.Lba, err) |
||||
|
} |
||||
|
|
||||
|
// TRIM entry has no data payload, just a header.
|
||||
|
walEnd := e.WalOffset + uint64(walEntryHeaderSize) |
||||
|
if walEnd > maxWALEnd { |
||||
|
maxWALEnd = walEnd |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if e.Lsn > maxLSN { |
||||
|
maxLSN = e.Lsn |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Fsync extent writes.
|
||||
|
if err := f.fd.Sync(); err != nil { |
||||
|
return fmt.Errorf("flusher: fsync extent: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Remove flushed entries from dirty map.
|
||||
|
f.mu.Lock() |
||||
|
for _, e := range entries { |
||||
|
// Only remove if the dirty map entry still has the same LSN
|
||||
|
// (a newer write may have updated it).
|
||||
|
_, currentLSN, _, ok := f.dirtyMap.Get(e.Lba) |
||||
|
if ok && currentLSN == e.Lsn { |
||||
|
f.dirtyMap.Delete(e.Lba) |
||||
|
} |
||||
|
} |
||||
|
f.checkpointLSN = maxLSN |
||||
|
f.checkpointTail = maxWALEnd |
||||
|
f.mu.Unlock() |
||||
|
|
||||
|
// Advance WAL tail to free space.
|
||||
|
if maxWALEnd > 0 { |
||||
|
f.wal.AdvanceTail(maxWALEnd) |
||||
|
} |
||||
|
|
||||
|
// Update superblock checkpoint.
|
||||
|
f.updateSuperblockCheckpoint(maxLSN, f.wal.Tail()) |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// updateSuperblockCheckpoint writes the updated checkpoint LSN and the
// current WAL head/tail positions into the on-disk superblock (at file
// offset 0) and fsyncs the result.
//
// NOTE(review): the walTail parameter is never used — the tail written to
// the superblock comes from f.wal.LogicalTail() instead. FlushOnce passes
// f.wal.Tail() here; confirm whether that value was meant to be persisted
// or whether the parameter should be removed.
func (f *Flusher) updateSuperblockCheckpoint(checkpointLSN uint64, walTail uint64) error {
	f.super.WALCheckpointLSN = checkpointLSN
	f.super.WALHead = f.wal.LogicalHead()
	f.super.WALTail = f.wal.LogicalTail()

	// The superblock lives at the start of the file.
	if _, err := f.fd.Seek(0, 0); err != nil {
		return fmt.Errorf("flusher: seek to superblock: %w", err)
	}
	if _, err := f.super.WriteTo(f.fd); err != nil {
		return fmt.Errorf("flusher: write superblock: %w", err)
	}
	return f.fd.Sync()
}
||||
|
|
||||
|
// CheckpointLSN returns the last flushed LSN.
|
||||
|
func (f *Flusher) CheckpointLSN() uint64 { |
||||
|
f.mu.Lock() |
||||
|
defer f.mu.Unlock() |
||||
|
return f.checkpointLSN |
||||
|
} |
||||
|
|
||||
|
// parseLength extracts the Length field from a WAL entry header buffer.
// The field sits after LSN(8) + Epoch(8) + Type(1) + Flags(1) + LBA(8),
// i.e. at byte offset 26, encoded little-endian.
func parseLength(headerBuf []byte) uint32 {
	const lengthOffset = 8 + 8 + 1 + 1 + 8 // = 26
	return binary.LittleEndian.Uint32(headerBuf[lengthOffset:])
}
||||
@ -0,0 +1,261 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"path/filepath" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestFlusher(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "flush_moves_data", run: testFlushMovesData}, |
||||
|
{name: "flush_idempotent", run: testFlushIdempotent}, |
||||
|
{name: "flush_concurrent_writes", run: testFlushConcurrentWrites}, |
||||
|
{name: "flush_frees_wal_space", run: testFlushFreesWALSpace}, |
||||
|
{name: "flush_partial", run: testFlushPartial}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func createTestVolWithFlusher(t *testing.T) (*BlockVol, *Flusher) { |
||||
|
t.Helper() |
||||
|
dir := t.TempDir() |
||||
|
path := filepath.Join(dir, "test.blockvol") |
||||
|
v, err := CreateBlockVol(path, CreateOptions{ |
||||
|
VolumeSize: 1 * 1024 * 1024, // 1MB
|
||||
|
BlockSize: 4096, |
||||
|
WALSize: 256 * 1024, // 256KB WAL
|
||||
|
}) |
||||
|
if err != nil { |
||||
|
t.Fatalf("CreateBlockVol: %v", err) |
||||
|
} |
||||
|
|
||||
|
f := NewFlusher(FlusherConfig{ |
||||
|
FD: v.fd, |
||||
|
Super: &v.super, |
||||
|
WAL: v.wal, |
||||
|
DirtyMap: v.dirtyMap, |
||||
|
Interval: 1 * time.Hour, // don't auto-flush in tests
|
||||
|
}) |
||||
|
|
||||
|
return v, f |
||||
|
} |
||||
|
|
||||
|
func testFlushMovesData(t *testing.T) { |
||||
|
v, f := createTestVolWithFlusher(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
// Write 10 blocks.
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if v.dirtyMap.Len() != 10 { |
||||
|
t.Fatalf("dirty map len = %d, want 10", v.dirtyMap.Len()) |
||||
|
} |
||||
|
|
||||
|
// Run flusher.
|
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Dirty map should be empty.
|
||||
|
if v.dirtyMap.Len() != 0 { |
||||
|
t.Errorf("after flush: dirty map len = %d, want 0", v.dirtyMap.Len()) |
||||
|
} |
||||
|
|
||||
|
// Checkpoint should have advanced.
|
||||
|
if f.CheckpointLSN() == 0 { |
||||
|
t.Error("checkpoint LSN should be > 0 after flush") |
||||
|
} |
||||
|
|
||||
|
// Read from extent (dirty map is empty, so reads go to extent).
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
got, err := v.ReadLBA(i, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA(%d) after flush: %v", i, err) |
||||
|
} |
||||
|
if !bytes.Equal(got, makeBlock(byte('A'+i))) { |
||||
|
t.Errorf("block %d: data mismatch after flush", i) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testFlushIdempotent(t *testing.T) { |
||||
|
v, f := createTestVolWithFlusher(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
data := makeBlock('X') |
||||
|
if err := v.WriteLBA(0, data); err != nil { |
||||
|
t.Fatalf("WriteLBA: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Flush twice.
|
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce 1: %v", err) |
||||
|
} |
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce 2: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Data should still be correct.
|
||||
|
got, err := v.ReadLBA(0, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA after double flush: %v", err) |
||||
|
} |
||||
|
if !bytes.Equal(got, data) { |
||||
|
t.Error("data mismatch after double flush") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testFlushConcurrentWrites(t *testing.T) { |
||||
|
v, f := createTestVolWithFlusher(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
// Write blocks 0-4.
|
||||
|
for i := uint64(0); i < 5; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Flush (moves blocks 0-4 to extent).
|
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Write blocks 5-9 AFTER flush.
|
||||
|
for i := uint64(5); i < 10; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Blocks 0-4 should read from extent, blocks 5-9 from WAL.
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
got, err := v.ReadLBA(i, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA(%d): %v", i, err) |
||||
|
} |
||||
|
if !bytes.Equal(got, makeBlock(byte('A'+i))) { |
||||
|
t.Errorf("block %d: data mismatch", i) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Dirty map should have 5 entries (blocks 5-9).
|
||||
|
if v.dirtyMap.Len() != 5 { |
||||
|
t.Errorf("dirty map len = %d, want 5", v.dirtyMap.Len()) |
||||
|
} |
||||
|
|
||||
|
// Also: overwrite block 0 after flush -- new write should go to WAL.
|
||||
|
newData := makeBlock('Z') |
||||
|
if err := v.WriteLBA(0, newData); err != nil { |
||||
|
t.Fatalf("WriteLBA(0) overwrite: %v", err) |
||||
|
} |
||||
|
got, err := v.ReadLBA(0, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA(0) after overwrite: %v", err) |
||||
|
} |
||||
|
if !bytes.Equal(got, newData) { |
||||
|
t.Error("block 0: should return overwritten data 'Z'") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testFlushFreesWALSpace fills most of the WAL, tries to drive it to the
// full condition, then verifies that a single FlushOnce frees enough space
// for new writes to succeed again.
func testFlushFreesWALSpace(t *testing.T) {
	v, f := createTestVolWithFlusher(t)
	defer v.Close()

	// Write enough blocks to fill a significant portion of WAL.
	// Capacity is estimated from the fixed per-entry size (header + one
	// 4KB block).
	entrySize := uint64(walEntryHeaderSize + 4096)
	walCapacity := v.super.WALSize / entrySize
	// Write ~80% of capacity.
	writeCount := int(walCapacity * 80 / 100)

	for i := 0; i < writeCount; i++ {
		if err := v.WriteLBA(uint64(i), makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	// Try to write more -- should eventually fail with WAL full.
	// The error is deliberately not type-checked here: any failure is
	// treated as "WAL became full" and only logged at the end.
	var walFullBefore bool
	for i := writeCount; i < writeCount+int(walCapacity); i++ {
		if err := v.WriteLBA(uint64(i%writeCount), makeBlock('X')); err != nil {
			walFullBefore = true
			break
		}
	}

	// Flush to free WAL space.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// WAL tail should have advanced (free space available).
	// New writes should succeed.
	if err := v.WriteLBA(0, makeBlock('Y')); err != nil {
		t.Fatalf("WriteLBA after flush: %v", err)
	}

	// Log whether WAL was full before flush (it may not have been if the
	// capacity estimate was conservative).
	if walFullBefore {
		t.Log("WAL was full before flush, writes succeeded after flush")
	}
}
||||
|
|
||||
|
func testFlushPartial(t *testing.T) { |
||||
|
v, f := createTestVolWithFlusher(t) |
||||
|
defer v.Close() |
||||
|
|
||||
|
// Write blocks 0-4.
|
||||
|
for i := uint64(0); i < 5; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Flush once (all 5 blocks).
|
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce: %v", err) |
||||
|
} |
||||
|
|
||||
|
checkpointAfterFirst := f.CheckpointLSN() |
||||
|
|
||||
|
// Write blocks 5-9.
|
||||
|
for i := uint64(5); i < 10; i++ { |
||||
|
if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil { |
||||
|
t.Fatalf("WriteLBA(%d): %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Simulate partial flush: flusher runs again, should handle new entries.
|
||||
|
if err := f.FlushOnce(); err != nil { |
||||
|
t.Fatalf("FlushOnce 2: %v", err) |
||||
|
} |
||||
|
|
||||
|
checkpointAfterSecond := f.CheckpointLSN() |
||||
|
if checkpointAfterSecond <= checkpointAfterFirst { |
||||
|
t.Errorf("checkpoint should advance: first=%d, second=%d", checkpointAfterFirst, checkpointAfterSecond) |
||||
|
} |
||||
|
|
||||
|
// All blocks should be readable from extent.
|
||||
|
for i := uint64(0); i < 10; i++ { |
||||
|
got, err := v.ReadLBA(i, 4096) |
||||
|
if err != nil { |
||||
|
t.Fatalf("ReadLBA(%d) after two flushes: %v", i, err) |
||||
|
} |
||||
|
if !bytes.Equal(got, makeBlock(byte('A'+i))) { |
||||
|
t.Errorf("block %d: data mismatch after two flushes", i) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,167 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
var ErrGroupCommitShutdown = errors.New("blockvol: group committer shut down") |
||||
|
|
||||
|
// GroupCommitter batches SyncCache requests and performs a single fsync
// for the entire batch. This amortizes the cost of fsync across many callers.
//
// Usage: construct with NewGroupCommitter, start Run() in its own goroutine,
// have callers block in Submit(), and shut down with Stop().
type GroupCommitter struct {
	syncFunc   func() error  // called to fsync (injectable for testing)
	maxDelay   time.Duration // max wait before flushing a partial batch
	maxBatch   int           // flush immediately when this many waiters accumulate
	onDegraded func()        // called when fsync fails

	mu       sync.Mutex
	pending  []chan error // one cap-1 channel per blocked Submit caller
	stopped  bool         // set under mu by Run() before draining
	notifyCh chan struct{} // cap-1 doorbell: Submit -> Run wakeup
	stopCh   chan struct{} // closed by Stop() to request shutdown
	done     chan struct{} // closed when Run() exits
	stopOnce sync.Once     // makes Stop() idempotent

	syncCount atomic.Uint64 // number of fsyncs performed (for testing)
}
||||
|
|
||||
|
// GroupCommitterConfig configures the group committer.
// Zero-valued optional fields are replaced with defaults by NewGroupCommitter.
type GroupCommitterConfig struct {
	SyncFunc   func() error  // required: the fsync function
	MaxDelay   time.Duration // default 1ms
	MaxBatch   int           // default 64
	OnDegraded func()        // optional: called on fsync error
}
||||
|
|
||||
|
// NewGroupCommitter creates a new group committer. Call Run() to start it.
|
||||
|
func NewGroupCommitter(cfg GroupCommitterConfig) *GroupCommitter { |
||||
|
if cfg.MaxDelay == 0 { |
||||
|
cfg.MaxDelay = 1 * time.Millisecond |
||||
|
} |
||||
|
if cfg.MaxBatch == 0 { |
||||
|
cfg.MaxBatch = 64 |
||||
|
} |
||||
|
if cfg.OnDegraded == nil { |
||||
|
cfg.OnDegraded = func() {} |
||||
|
} |
||||
|
return &GroupCommitter{ |
||||
|
syncFunc: cfg.SyncFunc, |
||||
|
maxDelay: cfg.MaxDelay, |
||||
|
maxBatch: cfg.MaxBatch, |
||||
|
onDegraded: cfg.OnDegraded, |
||||
|
notifyCh: make(chan struct{}, 1), |
||||
|
stopCh: make(chan struct{}), |
||||
|
done: make(chan struct{}), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Run is the main loop. Call this in a goroutine.
//
// Per batch:
//  1. Block until the first Submit() rings notifyCh, or Stop() closes stopCh.
//  2. Accumulate waiters for up to maxDelay, or until maxBatch waiters are
//     pending, whichever comes first.
//  3. Take the whole pending list, call syncFunc once, and deliver the
//     result (nil or the fsync error) to every waiter in the batch.
//
// On shutdown, markStoppedAndDrain fails any still-pending waiters with
// ErrGroupCommitShutdown before Run returns and closes gc.done.
func (gc *GroupCommitter) Run() {
	defer close(gc.done)
	for {
		// Wait for first waiter or shutdown.
		select {
		case <-gc.stopCh:
			gc.markStoppedAndDrain()
			return
		case <-gc.notifyCh:
		}

		// Collect batch: wait up to maxDelay for more waiters, or until maxBatch reached.
		deadline := time.NewTimer(gc.maxDelay)
		for {
			gc.mu.Lock()
			n := len(gc.pending)
			gc.mu.Unlock()
			if n >= gc.maxBatch {
				// Batch is full — flush immediately instead of waiting
				// out the timer.
				deadline.Stop()
				break
			}
			select {
			case <-gc.stopCh:
				deadline.Stop()
				gc.markStoppedAndDrain()
				return
			case <-deadline.C:
				goto flush
			case <-gc.notifyCh:
				// Another Submit() arrived; re-check the batch size.
				continue
			}
		}

	flush:
		// Take all pending waiters.
		gc.mu.Lock()
		batch := gc.pending
		gc.pending = nil
		gc.mu.Unlock()

		if len(batch) == 0 {
			// Spurious wakeup (e.g. a stale notify left over after a
			// previous flush already took the waiter). Nothing to do.
			continue
		}

		// Perform fsync. One syncFunc call covers the whole batch.
		err := gc.syncFunc()
		gc.syncCount.Add(1)
		if err != nil {
			gc.onDegraded()
		}

		// Wake all waiters. Each channel is buffered (cap 1), so these
		// sends never block even if a waiter has not yet reached its
		// receive.
		for _, ch := range batch {
			ch <- err
		}
	}
}
||||
|
|
||||
|
// Submit submits a sync request and blocks until the batch fsync completes.
|
||||
|
// Returns nil on success, the fsync error, or ErrGroupCommitShutdown if stopped.
|
||||
|
func (gc *GroupCommitter) Submit() error { |
||||
|
ch := make(chan error, 1) |
||||
|
gc.mu.Lock() |
||||
|
if gc.stopped { |
||||
|
gc.mu.Unlock() |
||||
|
return ErrGroupCommitShutdown |
||||
|
} |
||||
|
gc.pending = append(gc.pending, ch) |
||||
|
gc.mu.Unlock() |
||||
|
|
||||
|
// Non-blocking notify to wake Run().
|
||||
|
select { |
||||
|
case gc.notifyCh <- struct{}{}: |
||||
|
default: |
||||
|
} |
||||
|
|
||||
|
return <-ch |
||||
|
} |
||||
|
|
||||
|
// Stop shuts down the group committer. Pending waiters receive ErrGroupCommitShutdown.
|
||||
|
// Safe to call multiple times.
|
||||
|
func (gc *GroupCommitter) Stop() { |
||||
|
gc.stopOnce.Do(func() { |
||||
|
close(gc.stopCh) |
||||
|
}) |
||||
|
<-gc.done |
||||
|
} |
||||
|
|
||||
|
// SyncCount returns the number of fsyncs performed (for testing).
// Safe to call concurrently with Run().
func (gc *GroupCommitter) SyncCount() uint64 {
	return gc.syncCount.Load()
}
||||
|
|
||||
|
// markStoppedAndDrain sets the stopped flag under mu and drains all pending
|
||||
|
// waiters with ErrGroupCommitShutdown. This ensures no new Submit() can
|
||||
|
// enqueue after we drain, closing the race window from QA-002.
|
||||
|
func (gc *GroupCommitter) markStoppedAndDrain() { |
||||
|
gc.mu.Lock() |
||||
|
gc.stopped = true |
||||
|
batch := gc.pending |
||||
|
gc.pending = nil |
||||
|
gc.mu.Unlock() |
||||
|
for _, ch := range batch { |
||||
|
ch <- ErrGroupCommitShutdown |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,287 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestGroupCommitter(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "group_single_barrier", run: testGroupSingleBarrier}, |
||||
|
{name: "group_batch_10", run: testGroupBatch10}, |
||||
|
{name: "group_max_delay", run: testGroupMaxDelay}, |
||||
|
{name: "group_max_batch", run: testGroupMaxBatch}, |
||||
|
{name: "group_fsync_error", run: testGroupFsyncError}, |
||||
|
{name: "group_sequential", run: testGroupSequential}, |
||||
|
{name: "group_shutdown", run: testGroupShutdown}, |
||||
|
{name: "group_submit_after_stop", run: testGroupSubmitAfterStop}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupSingleBarrier(t *testing.T) { |
||||
|
var syncCalls atomic.Uint64 |
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { |
||||
|
syncCalls.Add(1) |
||||
|
return nil |
||||
|
}, |
||||
|
MaxDelay: 10 * time.Millisecond, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
if err := gc.Submit(); err != nil { |
||||
|
t.Fatalf("Submit: %v", err) |
||||
|
} |
||||
|
|
||||
|
if c := syncCalls.Load(); c != 1 { |
||||
|
t.Errorf("syncCalls = %d, want 1", c) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupBatch10(t *testing.T) { |
||||
|
var syncCalls atomic.Uint64 |
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { |
||||
|
syncCalls.Add(1) |
||||
|
return nil |
||||
|
}, |
||||
|
MaxDelay: 50 * time.Millisecond, |
||||
|
MaxBatch: 64, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
const n = 10 |
||||
|
var wg sync.WaitGroup |
||||
|
errs := make([]error, n) |
||||
|
|
||||
|
// Launch all 10 concurrently.
|
||||
|
wg.Add(n) |
||||
|
for i := 0; i < n; i++ { |
||||
|
go func(idx int) { |
||||
|
defer wg.Done() |
||||
|
errs[idx] = gc.Submit() |
||||
|
}(i) |
||||
|
} |
||||
|
wg.Wait() |
||||
|
|
||||
|
for i, err := range errs { |
||||
|
if err != nil { |
||||
|
t.Errorf("Submit[%d]: %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// All 10 should be batched into 1 fsync (maybe 2 if timing is unlucky).
|
||||
|
if c := syncCalls.Load(); c > 2 { |
||||
|
t.Errorf("syncCalls = %d, want 1-2 (batched)", c) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupMaxDelay(t *testing.T) { |
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { return nil }, |
||||
|
MaxDelay: 5 * time.Millisecond, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
start := time.Now() |
||||
|
if err := gc.Submit(); err != nil { |
||||
|
t.Fatalf("Submit: %v", err) |
||||
|
} |
||||
|
elapsed := time.Since(start) |
||||
|
|
||||
|
// Should complete within ~maxDelay + some margin, not take much longer.
|
||||
|
if elapsed > 50*time.Millisecond { |
||||
|
t.Errorf("Submit took %v, expected ~5ms", elapsed) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupMaxBatch(t *testing.T) { |
||||
|
var syncCalls atomic.Uint64 |
||||
|
const batch = 64 |
||||
|
|
||||
|
// Use a slow sync to ensure we can detect immediate trigger.
|
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { |
||||
|
syncCalls.Add(1) |
||||
|
return nil |
||||
|
}, |
||||
|
MaxDelay: 5 * time.Second, // very long delay -- should NOT wait this long
|
||||
|
MaxBatch: batch, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
var wg sync.WaitGroup |
||||
|
wg.Add(batch) |
||||
|
for i := 0; i < batch; i++ { |
||||
|
go func() { |
||||
|
defer wg.Done() |
||||
|
gc.Submit() |
||||
|
}() |
||||
|
} |
||||
|
|
||||
|
// Wait with timeout -- if maxBatch triggers, should complete fast.
|
||||
|
done := make(chan struct{}) |
||||
|
go func() { |
||||
|
wg.Wait() |
||||
|
close(done) |
||||
|
}() |
||||
|
|
||||
|
select { |
||||
|
case <-done: |
||||
|
// Good -- completed without waiting 5 seconds.
|
||||
|
case <-time.After(2 * time.Second): |
||||
|
t.Fatal("maxBatch=64 did not trigger immediate flush") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupFsyncError(t *testing.T) { |
||||
|
errIO := errors.New("simulated EIO") |
||||
|
var degraded atomic.Bool |
||||
|
|
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { |
||||
|
return errIO |
||||
|
}, |
||||
|
MaxDelay: 10 * time.Millisecond, |
||||
|
OnDegraded: func() { degraded.Store(true) }, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
const n = 5 |
||||
|
var wg sync.WaitGroup |
||||
|
errs := make([]error, n) |
||||
|
|
||||
|
wg.Add(n) |
||||
|
for i := 0; i < n; i++ { |
||||
|
go func(idx int) { |
||||
|
defer wg.Done() |
||||
|
errs[idx] = gc.Submit() |
||||
|
}(i) |
||||
|
} |
||||
|
wg.Wait() |
||||
|
|
||||
|
for i, err := range errs { |
||||
|
if !errors.Is(err, errIO) { |
||||
|
t.Errorf("Submit[%d]: got %v, want %v", i, err, errIO) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if !degraded.Load() { |
||||
|
t.Error("onDegraded was not called") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGroupSequential(t *testing.T) { |
||||
|
var syncCalls atomic.Uint64 |
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { |
||||
|
syncCalls.Add(1) |
||||
|
return nil |
||||
|
}, |
||||
|
MaxDelay: 10 * time.Millisecond, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
defer gc.Stop() |
||||
|
|
||||
|
// First sync.
|
||||
|
if err := gc.Submit(); err != nil { |
||||
|
t.Fatalf("Submit 1: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Second sync (after first completes).
|
||||
|
if err := gc.Submit(); err != nil { |
||||
|
t.Fatalf("Submit 2: %v", err) |
||||
|
} |
||||
|
|
||||
|
if c := syncCalls.Load(); c != 2 { |
||||
|
t.Errorf("syncCalls = %d, want 2 (two separate fsyncs)", c) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testGroupShutdown verifies Stop() semantics while an fsync is in flight:
// the waiter already taken into the fsync batch completes normally, while
// a waiter that enqueued after the batch was taken is drained with
// ErrGroupCommitShutdown.
func testGroupShutdown(t *testing.T) {
	// Block fsync so we can shut down while waiters are pending.
	syncStarted := make(chan struct{})
	syncBlock := make(chan struct{})
	gc := NewGroupCommitter(GroupCommitterConfig{
		SyncFunc: func() error {
			close(syncStarted)
			<-syncBlock // block until test releases
			return nil
		},
		MaxDelay: 1 * time.Millisecond,
	})
	go gc.Run()

	// Submit one request that will block on fsync.
	errCh1 := make(chan error, 1)
	go func() {
		errCh1 <- gc.Submit()
	}()

	// Wait for fsync to start.
	<-syncStarted

	// Submit another request while fsync is in progress.
	errCh2 := make(chan error, 1)
	go func() {
		errCh2 <- gc.Submit()
	}()
	time.Sleep(5 * time.Millisecond) // let it enqueue

	// Stop the group committer (will drain pending with ErrGroupCommitShutdown).
	// The fsync is unblocked shortly afterwards so Run() can exit and
	// Stop() — which waits on gc.done — can return.
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(syncBlock) // unblock the fsync first
	}()
	gc.Stop()

	// First waiter should get nil (fsync succeeded).
	if err := <-errCh1; err != nil {
		t.Errorf("first waiter: %v, want nil", err)
	}

	// Second waiter should get shutdown error.
	if err := <-errCh2; !errors.Is(err, ErrGroupCommitShutdown) {
		t.Errorf("second waiter: %v, want ErrGroupCommitShutdown", err)
	}
}
||||
|
|
||||
|
func testGroupSubmitAfterStop(t *testing.T) { |
||||
|
gc := NewGroupCommitter(GroupCommitterConfig{ |
||||
|
SyncFunc: func() error { return nil }, |
||||
|
MaxDelay: 10 * time.Millisecond, |
||||
|
}) |
||||
|
go gc.Run() |
||||
|
gc.Stop() |
||||
|
|
||||
|
// Submit after Stop must return ErrGroupCommitShutdown, not deadlock.
|
||||
|
done := make(chan error, 1) |
||||
|
go func() { |
||||
|
done <- gc.Submit() |
||||
|
}() |
||||
|
|
||||
|
select { |
||||
|
case err := <-done: |
||||
|
if !errors.Is(err, ErrGroupCommitShutdown) { |
||||
|
t.Errorf("Submit after Stop: %v, want ErrGroupCommitShutdown", err) |
||||
|
} |
||||
|
case <-time.After(2 * time.Second): |
||||
|
t.Fatal("Submit after Stop deadlocked") |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,141 @@ |
|||||
|
// iscsi-target is a standalone iSCSI target backed by a BlockVol file.
|
||||
|
// Usage:
|
||||
|
//
|
||||
|
// iscsi-target -vol /path/to/volume.blk -addr :3260 -iqn iqn.2024.com.seaweedfs:vol1
|
||||
|
// iscsi-target -create -size 1G -vol /path/to/volume.blk -addr :3260
|
||||
|
package main |
||||
|
|
||||
|
import (
	"flag"
	"fmt"
	"log"
	"math"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"

	"github.com/seaweedfs/seaweedfs/weed/storage/blockvol"
	"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi"
)
||||
|
|
||||
|
// main parses flags, creates or opens the backing BlockVol file, binds it
// to an iSCSI target server under the configured IQN, and serves until a
// SIGINT/SIGTERM triggers a graceful shutdown.
func main() {
	volPath := flag.String("vol", "", "path to BlockVol file")
	addr := flag.String("addr", ":3260", "listen address")
	iqn := flag.String("iqn", "iqn.2024.com.seaweedfs:vol1", "target IQN")
	create := flag.Bool("create", false, "create a new volume file")
	size := flag.String("size", "1G", "volume size (e.g., 1G, 100M) — used with -create")
	flag.Parse()

	if *volPath == "" {
		fmt.Fprintln(os.Stderr, "error: -vol is required")
		flag.Usage()
		os.Exit(1)
	}

	logger := log.New(os.Stdout, "[iscsi] ", log.LstdFlags)

	var vol *blockvol.BlockVol
	var err error

	// Either create a fresh volume file (-create, sized by -size) or open
	// an existing one.
	if *create {
		volSize, parseErr := parseSize(*size)
		if parseErr != nil {
			log.Fatalf("invalid size %q: %v", *size, parseErr)
		}
		vol, err = blockvol.CreateBlockVol(*volPath, blockvol.CreateOptions{
			VolumeSize: volSize,
			BlockSize:  4096,
			WALSize:    64 * 1024 * 1024,
		})
		if err != nil {
			log.Fatalf("create volume: %v", err)
		}
		logger.Printf("created volume: %s (%s)", *volPath, *size)
	} else {
		vol, err = blockvol.OpenBlockVol(*volPath)
		if err != nil {
			log.Fatalf("open volume: %v", err)
		}
		logger.Printf("opened volume: %s", *volPath)
	}
	// NOTE(review): log.Fatalf calls os.Exit and skips this defer; the
	// volume is only closed on the normal shutdown path.
	defer vol.Close()

	info := vol.Info()
	logger.Printf("volume: %d bytes, block=%d, healthy=%v",
		info.VolumeSize, info.BlockSize, info.Healthy)

	// Create adapter bridging *blockvol.BlockVol to iscsi.BlockDevice.
	adapter := &blockVolAdapter{vol: vol}

	// Create target server and bind the device under the target IQN.
	config := iscsi.DefaultTargetConfig()
	config.TargetName = *iqn
	config.TargetAlias = "SeaweedFS BlockVol"
	ts := iscsi.NewTargetServer(*addr, config, logger)
	ts.AddVolume(*iqn, adapter)

	// Graceful shutdown on signal — ts.Close() is expected to make
	// ListenAndServe return (confirm against TargetServer's contract).
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		sig := <-sigCh
		logger.Printf("received %v, shutting down...", sig)
		ts.Close()
	}()

	logger.Printf("starting iSCSI target: %s on %s", *iqn, *addr)
	if err := ts.ListenAndServe(); err != nil {
		log.Fatalf("target server: %v", err)
	}
	logger.Println("target stopped")
}
||||
|
|
||||
|
// blockVolAdapter wraps BlockVol to implement iscsi.BlockDevice.
// It is a thin pass-through: every method delegates directly to the
// underlying volume with no caching or extra locking.
type blockVolAdapter struct {
	vol *blockvol.BlockVol
}
||||
|
|
||||
|
// ReadAt reads length bytes starting at the given LBA from the volume.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return a.vol.ReadLBA(lba, length)
}
||||
|
// WriteAt writes data starting at the given LBA.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
	return a.vol.WriteLBA(lba, data)
}
||||
|
// Trim discards length bytes starting at the given LBA (UNMAP support).
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
	return a.vol.Trim(lba, length)
}
||||
|
// SyncCache flushes buffered writes to stable storage.
func (a *blockVolAdapter) SyncCache() error { return a.vol.SyncCache() }

// BlockSize returns the volume's block size in bytes.
func (a *blockVolAdapter) BlockSize() uint32 { return a.vol.Info().BlockSize }

// VolumeSize returns the total volume size in bytes.
func (a *blockVolAdapter) VolumeSize() uint64 { return a.vol.Info().VolumeSize }

// IsHealthy reports whether the underlying volume reports itself healthy.
func (a *blockVolAdapter) IsHealthy() bool { return a.vol.Info().Healthy }
||||
|
|
||||
|
// parseSize parses a human-readable size string such as "100", "64K",
// "100M", "1G", or "2T" (suffixes are case-insensitive powers of 1024)
// and returns the size in bytes. It returns an error for empty input, a
// bare suffix with no digits, a non-numeric value, or a result that
// would overflow uint64 (the original silently wrapped on overflow).
func parseSize(s string) (uint64, error) {
	s = strings.TrimSpace(s)
	if len(s) == 0 {
		return 0, fmt.Errorf("empty size")
	}

	multiplier := uint64(1)
	switch s[len(s)-1] {
	case 'K', 'k':
		multiplier = 1 << 10
		s = s[:len(s)-1]
	case 'M', 'm':
		multiplier = 1 << 20
		s = s[:len(s)-1]
	case 'G', 'g':
		multiplier = 1 << 30
		s = s[:len(s)-1]
	case 'T', 't':
		multiplier = 1 << 40
		s = s[:len(s)-1]
	}
	if len(s) == 0 {
		// Input was just a suffix, e.g. "G".
		return 0, fmt.Errorf("missing numeric value")
	}

	n, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return 0, err
	}
	// Reject products that would wrap around uint64.
	if n != 0 && n > math.MaxUint64/multiplier {
		return 0, fmt.Errorf("size %q overflows uint64", s)
	}
	return n * multiplier, nil
}
||||
@ -0,0 +1,216 @@ |
|||||
|
#!/usr/bin/env bash
# smoke-test.sh — iscsiadm smoke test for SeaweedFS iSCSI target
#
# Prerequisites:
# - Linux host with iscsiadm (open-iscsi) installed
# - Root access (for iscsiadm, mkfs, mount)
# - iscsi-target binary built: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/
#
# Usage:
#   sudo ./smoke-test.sh [target-addr] [target-port]
#
# Default: starts a local target on 127.0.0.1:3260, runs discovery + login + I/O,
# then cleans up (the cleanup trap runs on every exit, pass or fail).

set -euo pipefail

TARGET_ADDR="${1:-127.0.0.1}"
TARGET_PORT="${2:-3260}"
TARGET_IQN="iqn.2024.com.seaweedfs:smoke"
VOL_FILE="/tmp/iscsi-smoke-test.blk"
MOUNT_POINT="/tmp/iscsi-smoke-mnt"
TARGET_PID=""

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

log() { echo -e "${GREEN}[SMOKE]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; }

# cleanup tears down everything the test created, in reverse order of
# creation. Each step is best-effort (|| true) so one failure does not
# prevent later resources from being released.
cleanup() {
    log "Cleaning up..."

    # Unmount if mounted
    if mountpoint -q "$MOUNT_POINT" 2>/dev/null; then
        umount "$MOUNT_POINT" || true
    fi
    rmdir "$MOUNT_POINT" 2>/dev/null || true

    # Logout from iSCSI
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout 2>/dev/null || true
    iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" -o delete 2>/dev/null || true

    # Stop target
    if [[ -n "$TARGET_PID" ]] && kill -0 "$TARGET_PID" 2>/dev/null; then
        kill "$TARGET_PID"
        wait "$TARGET_PID" 2>/dev/null || true
    fi

    # Remove volume file
    rm -f "$VOL_FILE"

    log "Cleanup complete"
}

trap cleanup EXIT

# -------------------------------------------------------
# Preflight: root + iscsiadm + target binary must exist
# -------------------------------------------------------
if [[ $EUID -ne 0 ]]; then
    fail "This script must be run as root (for iscsiadm/mount)"
fi

if ! command -v iscsiadm &>/dev/null; then
    fail "iscsiadm not found. Install open-iscsi: apt install open-iscsi"
fi

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TARGET_BIN="${SCRIPT_DIR}/iscsi-target"
if [[ ! -x "$TARGET_BIN" ]]; then
    # Try the repo root
    TARGET_BIN="$(cd "$SCRIPT_DIR/../../../../../.." && pwd)/iscsi-target"
fi
if [[ ! -x "$TARGET_BIN" ]]; then
    fail "iscsi-target binary not found. Build with: go build ./weed/storage/blockvol/iscsi/cmd/iscsi-target/"
fi

# -------------------------------------------------------
# Step 1: Start iSCSI target (fresh 100M volume, background)
# -------------------------------------------------------
log "Starting iSCSI target..."
"$TARGET_BIN" \
    -create \
    -vol "$VOL_FILE" \
    -size 100M \
    -addr "${TARGET_ADDR}:${TARGET_PORT}" \
    -iqn "$TARGET_IQN" &
TARGET_PID=$!
sleep 1

if ! kill -0 "$TARGET_PID" 2>/dev/null; then
    fail "Target failed to start"
fi
log "Target started (PID $TARGET_PID)"

# -------------------------------------------------------
# Step 2: Discovery — target IQN must appear in SendTargets
# -------------------------------------------------------
log "Running iscsiadm discovery..."
DISCOVERY=$(iscsiadm -m discovery -t sendtargets -p "${TARGET_ADDR}:${TARGET_PORT}" 2>&1) || {
    fail "Discovery failed: $DISCOVERY"
}
echo "$DISCOVERY"

if ! echo "$DISCOVERY" | grep -q "$TARGET_IQN"; then
    fail "Target IQN not found in discovery output"
fi
log "Discovery OK"

# -------------------------------------------------------
# Step 3: Login
# -------------------------------------------------------
log "Logging in to target..."
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --login || {
    fail "Login failed"
}
log "Login OK"

# Wait for device to appear, then find the attached disk name from the
# session details. Falls back to /dev/sdb if parsing fails.
sleep 2
ISCSI_DEV=$(iscsiadm -m session -P 3 2>/dev/null | grep "Attached scsi disk" | awk '{print $4}' | head -1)
if [[ -z "$ISCSI_DEV" ]]; then
    warn "Could not determine attached device — trying /dev/sdb"
    ISCSI_DEV="sdb"
fi
DEV_PATH="/dev/$ISCSI_DEV"
log "Attached device: $DEV_PATH"

if [[ ! -b "$DEV_PATH" ]]; then
    fail "Block device $DEV_PATH not found"
fi

# -------------------------------------------------------
# Step 4: dd write/read test (O_DIRECT to bypass the page cache)
# -------------------------------------------------------
log "Writing test pattern with dd..."
dd if=/dev/urandom of=/tmp/iscsi-smoke-pattern bs=1M count=1 2>/dev/null
dd if=/tmp/iscsi-smoke-pattern of="$DEV_PATH" bs=1M count=1 oflag=direct 2>/dev/null || {
    fail "dd write failed"
}

log "Reading back and verifying..."
dd if="$DEV_PATH" of=/tmp/iscsi-smoke-readback bs=1M count=1 iflag=direct 2>/dev/null || {
    fail "dd read failed"
}

if cmp -s /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback; then
    log "Data integrity verified (dd write/read OK)"
else
    fail "Data integrity check failed!"
fi
rm -f /tmp/iscsi-smoke-pattern /tmp/iscsi-smoke-readback

# -------------------------------------------------------
# Step 5: mkfs + mount + file I/O (optional — dd alone is a pass)
# -------------------------------------------------------
log "Creating ext4 filesystem..."
mkfs.ext4 -F -q "$DEV_PATH" || {
    warn "mkfs.ext4 failed — skipping filesystem tests"
    # Still consider the test a pass if dd worked
    log "SMOKE TEST PASSED (dd only, mkfs skipped)"
    exit 0
}

mkdir -p "$MOUNT_POINT"
mount "$DEV_PATH" "$MOUNT_POINT" || {
    fail "mount failed"
}
log "Mounted at $MOUNT_POINT"

# Write a file
echo "SeaweedFS iSCSI smoke test" > "$MOUNT_POINT/test.txt"
sync

# Read it back
CONTENT=$(cat "$MOUNT_POINT/test.txt")
if [[ "$CONTENT" != "SeaweedFS iSCSI smoke test" ]]; then
    fail "File content mismatch"
fi
log "File I/O verified"

# -------------------------------------------------------
# Step 6: Logout
# -------------------------------------------------------
log "Unmounting and logging out..."
umount "$MOUNT_POINT"
iscsiadm -m node -T "$TARGET_IQN" -p "${TARGET_ADDR}:${TARGET_PORT}" --logout || {
    fail "Logout failed"
}
log "Logout OK"

# Verify no stale sessions
SESSION_COUNT=$(iscsiadm -m session 2>/dev/null | grep -c "$TARGET_IQN" || true)
if [[ "$SESSION_COUNT" -gt 0 ]]; then
    fail "Stale session detected after logout"
fi
log "No stale sessions"

# -------------------------------------------------------
# Result
# -------------------------------------------------------
echo ""
echo -e "${GREEN}========================================${NC}"
echo -e "${GREEN}  SMOKE TEST PASSED${NC}"
echo -e "${GREEN}========================================${NC}"
echo ""
echo "  Discovery:   OK"
echo "  Login:       OK"
echo "  dd I/O:      OK"
echo "  mkfs+mount:  OK"
echo "  File I/O:    OK"
echo "  Logout:      OK"
echo "  Cleanup:     OK"
echo ""
@ -0,0 +1,207 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"io" |
||||
|
) |
||||
|
|
||||
|
var (
	// ErrDataSNOrder reports a Data-Out PDU whose DataSN does not match
	// the next expected sequence number.
	ErrDataSNOrder = errors.New("iscsi: Data-Out DataSN out of order")
	// ErrDataOverflow reports data that would extend past the expected
	// transfer length.
	ErrDataOverflow = errors.New("iscsi: data exceeds expected transfer length")
	// ErrDataIncomplete reports a transfer that ended before all expected
	// bytes were received.
	ErrDataIncomplete = errors.New("iscsi: data transfer incomplete")
)
||||
|
|
||||
|
// DataInWriter splits a read response into multiple Data-In PDUs, respecting
// MaxRecvDataSegmentLength. The final PDU carries the S-bit (status) and F-bit.
type DataInWriter struct {
	maxSegLen uint32 // negotiated MaxRecvDataSegmentLength; caps each PDU's data segment
}
||||
|
|
||||
|
// NewDataInWriter creates a writer with the given max segment length.
|
||||
|
func NewDataInWriter(maxSegLen uint32) *DataInWriter { |
||||
|
if maxSegLen == 0 { |
||||
|
maxSegLen = 8192 // sensible default
|
||||
|
} |
||||
|
return &DataInWriter{maxSegLen: maxSegLen} |
||||
|
} |
||||
|
|
||||
|
// WriteDataIn splits data into Data-In PDUs and writes them to w.
// itt is the initiator task tag. statSN is the current StatSN
// (incremented only when the S-bit is set, i.e. once per command).
// Returns the number of PDUs written.
//
// Each PDU carries at most maxSegLen bytes at a monotonically increasing
// BufferOffset/DataSN. Only the final PDU sets F+S (phase-collapsed good
// status, so no separate SCSI Response PDU follows).
func (d *DataInWriter) WriteDataIn(w io.Writer, data []byte, itt uint32, expCmdSN, maxCmdSN uint32, statSN *uint32) (int, error) {
	totalLen := uint32(len(data))
	if totalLen == 0 {
		// Zero-length read — send single Data-In with S-bit, no data
		pdu := &PDU{}
		pdu.SetOpcode(OpSCSIDataIn)
		pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
		pdu.SetInitiatorTaskTag(itt)
		pdu.SetTargetTransferTag(0xFFFFFFFF) // reserved value: no TTT in use
		pdu.SetStatSN(*statSN)
		*statSN++
		pdu.SetExpCmdSN(expCmdSN)
		pdu.SetMaxCmdSN(maxCmdSN)
		pdu.SetDataSN(0)
		pdu.SetSCSIStatus(SCSIStatusGood)
		if err := WritePDU(w, pdu); err != nil {
			return 0, err
		}
		return 1, nil
	}

	var offset uint32
	var dataSN uint32
	count := 0

	for offset < totalLen {
		// Clamp the segment to the bytes remaining.
		segLen := d.maxSegLen
		if offset+segLen > totalLen {
			segLen = totalLen - offset
		}
		isFinal := (offset + segLen) >= totalLen

		pdu := &PDU{}
		pdu.SetOpcode(OpSCSIDataIn)
		pdu.SetInitiatorTaskTag(itt)
		pdu.SetTargetTransferTag(0xFFFFFFFF)
		pdu.SetExpCmdSN(expCmdSN)
		pdu.SetMaxCmdSN(maxCmdSN)
		pdu.SetDataSN(dataSN)
		pdu.SetBufferOffset(offset)

		// The data segment aliases the caller's buffer (no copy); the
		// caller must not mutate data until WriteDataIn returns.
		pdu.DataSegment = data[offset : offset+segLen]

		if isFinal {
			pdu.SetOpSpecific1(FlagF | FlagS) // Final + Status
			pdu.SetStatSN(*statSN)
			*statSN++
			pdu.SetSCSIStatus(SCSIStatusGood)
		} else {
			pdu.SetOpSpecific1(0) // no flags
		}

		if err := WritePDU(w, pdu); err != nil {
			return count, err
		}
		count++
		dataSN++
		offset += segLen
	}

	return count, nil
}
||||
|
|
||||
|
// DataOutCollector collects Data-Out PDUs for a single write command,
// assembling the full data buffer from potentially multiple PDUs.
// It handles both immediate data and R2T-solicited data.
type DataOutCollector struct {
	expectedLen uint32 // total bytes expected for this command
	buf         []byte // assembled payload, expectedLen bytes
	received    uint32 // bytes received so far (immediate + Data-Out)
	nextDataSN  uint32 // next expected DataSN, enforcing in-order delivery
	done        bool   // all bytes received, or F-bit seen on a Data-Out PDU
}
||||
|
|
||||
|
// NewDataOutCollector creates a collector expecting the given total transfer length.
|
||||
|
func NewDataOutCollector(expectedLen uint32) *DataOutCollector { |
||||
|
return &DataOutCollector{ |
||||
|
expectedLen: expectedLen, |
||||
|
buf: make([]byte, expectedLen), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// AddImmediateData adds the data from the SCSI Command PDU (immediate data).
|
||||
|
func (c *DataOutCollector) AddImmediateData(data []byte) error { |
||||
|
if uint32(len(data)) > c.expectedLen { |
||||
|
return ErrDataOverflow |
||||
|
} |
||||
|
copy(c.buf, data) |
||||
|
c.received += uint32(len(data)) |
||||
|
if c.received >= c.expectedLen { |
||||
|
c.done = true |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// AddDataOut processes a Data-Out PDU and adds its data to the buffer.
|
||||
|
func (c *DataOutCollector) AddDataOut(pdu *PDU) error { |
||||
|
dataSN := pdu.DataSN() |
||||
|
if dataSN != c.nextDataSN { |
||||
|
return ErrDataSNOrder |
||||
|
} |
||||
|
c.nextDataSN++ |
||||
|
|
||||
|
offset := pdu.BufferOffset() |
||||
|
data := pdu.DataSegment |
||||
|
end := offset + uint32(len(data)) |
||||
|
|
||||
|
if end > c.expectedLen { |
||||
|
return ErrDataOverflow |
||||
|
} |
||||
|
|
||||
|
copy(c.buf[offset:], data) |
||||
|
c.received += uint32(len(data)) |
||||
|
|
||||
|
// Check F-bit
|
||||
|
if pdu.OpSpecific1()&FlagF != 0 { |
||||
|
c.done = true |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Done returns true if all expected data has been received (or the
// initiator set the F-bit on a Data-Out PDU).
func (c *DataOutCollector) Done() bool { return c.done }
||||
|
|
||||
|
// Data returns the assembled data buffer. The slice is the collector's
// internal buffer, not a copy.
func (c *DataOutCollector) Data() []byte { return c.buf }
||||
|
|
||||
|
// Remaining returns how many bytes are still needed.
|
||||
|
func (c *DataOutCollector) Remaining() uint32 { |
||||
|
if c.received >= c.expectedLen { |
||||
|
return 0 |
||||
|
} |
||||
|
return c.expectedLen - c.received |
||||
|
} |
||||
|
|
||||
|
// BuildR2T creates an R2T PDU requesting more data from the initiator.
|
||||
|
func BuildR2T(itt, ttt uint32, r2tSN uint32, bufferOffset, desiredLen uint32, statSN, expCmdSN, maxCmdSN uint32) *PDU { |
||||
|
pdu := &PDU{} |
||||
|
pdu.SetOpcode(OpR2T) |
||||
|
pdu.SetOpSpecific1(FlagF) // always Final for R2T
|
||||
|
pdu.SetInitiatorTaskTag(itt) |
||||
|
pdu.SetTargetTransferTag(ttt) |
||||
|
pdu.SetStatSN(statSN) |
||||
|
pdu.SetExpCmdSN(expCmdSN) |
||||
|
pdu.SetMaxCmdSN(maxCmdSN) |
||||
|
pdu.SetR2TSN(r2tSN) |
||||
|
pdu.SetBufferOffset(bufferOffset) |
||||
|
pdu.SetDesiredDataLength(desiredLen) |
||||
|
return pdu |
||||
|
} |
||||
|
|
||||
|
// SendSCSIResponse sends a SCSI Response PDU with optional sense data.
|
||||
|
func SendSCSIResponse(w io.Writer, result SCSIResult, itt uint32, statSN *uint32, expCmdSN, maxCmdSN uint32) error { |
||||
|
pdu := &PDU{} |
||||
|
pdu.SetOpcode(OpSCSIResp) |
||||
|
pdu.SetOpSpecific1(FlagF) // Final
|
||||
|
pdu.SetSCSIResponse(ISCSIRespCompleted) |
||||
|
pdu.SetSCSIStatus(result.Status) |
||||
|
pdu.SetInitiatorTaskTag(itt) |
||||
|
pdu.SetStatSN(*statSN) |
||||
|
*statSN++ |
||||
|
pdu.SetExpCmdSN(expCmdSN) |
||||
|
pdu.SetMaxCmdSN(maxCmdSN) |
||||
|
|
||||
|
if result.Status == SCSIStatusCheckCond { |
||||
|
senseData := BuildSenseData(result.SenseKey, result.SenseASC, result.SenseASCQ) |
||||
|
// Sense data is wrapped in a 2-byte length prefix
|
||||
|
pdu.DataSegment = make([]byte, 2+len(senseData)) |
||||
|
pdu.DataSegment[0] = byte(len(senseData) >> 8) |
||||
|
pdu.DataSegment[1] = byte(len(senseData)) |
||||
|
copy(pdu.DataSegment[2:], senseData) |
||||
|
} |
||||
|
|
||||
|
return WritePDU(w, pdu) |
||||
|
} |
||||
@ -0,0 +1,410 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestDataIO(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{"datain_single_pdu", testDataInSinglePDU}, |
||||
|
{"datain_multi_pdu", testDataInMultiPDU}, |
||||
|
{"datain_exact_boundary", testDataInExactBoundary}, |
||||
|
{"datain_zero_length", testDataInZeroLength}, |
||||
|
{"datain_datasn_ordering", testDataInDataSNOrdering}, |
||||
|
{"datain_fbit_sbit", testDataInFbitSbit}, |
||||
|
{"dataout_single_pdu", testDataOutSinglePDU}, |
||||
|
{"dataout_multi_pdu", testDataOutMultiPDU}, |
||||
|
{"dataout_immediate_data", testDataOutImmediateData}, |
||||
|
{"dataout_immediate_plus_r2t", testDataOutImmediatePlusR2T}, |
||||
|
{"dataout_wrong_datasn", testDataOutWrongDataSN}, |
||||
|
{"dataout_overflow", testDataOutOverflow}, |
||||
|
{"r2t_build", testR2TBuild}, |
||||
|
{"scsi_response_good", testSCSIResponseGood}, |
||||
|
{"scsi_response_check_condition", testSCSIResponseCheckCondition}, |
||||
|
{"datain_statsn_increment", testDataInStatSNIncrement}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataInSinglePDU(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(8192) |
||||
|
data := bytes.Repeat([]byte{0xAA}, 4096) |
||||
|
statSN := uint32(1) |
||||
|
|
||||
|
n, err := dw.WriteDataIn(w, data, 0x100, 1, 10, &statSN) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if n != 1 { |
||||
|
t.Fatalf("expected 1 PDU, got %d", n) |
||||
|
} |
||||
|
if statSN != 2 { |
||||
|
t.Fatalf("StatSN should be 2, got %d", statSN) |
||||
|
} |
||||
|
|
||||
|
pdu, err := ReadPDU(w) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if pdu.Opcode() != OpSCSIDataIn { |
||||
|
t.Fatal("wrong opcode") |
||||
|
} |
||||
|
if pdu.OpSpecific1()&FlagF == 0 { |
||||
|
t.Fatal("F-bit not set") |
||||
|
} |
||||
|
if pdu.OpSpecific1()&FlagS == 0 { |
||||
|
t.Fatal("S-bit not set on final PDU") |
||||
|
} |
||||
|
if len(pdu.DataSegment) != 4096 { |
||||
|
t.Fatalf("data length: %d", len(pdu.DataSegment)) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testDataInMultiPDU verifies a 3000-byte payload splits into three Data-In
// PDUs (1024+1024+952) with in-order DataSNs, F/S bits only on the last
// PDU, and byte-identical reassembly.
func testDataInMultiPDU(t *testing.T) {
	w := &bytes.Buffer{}
	dw := NewDataInWriter(1024) // small segment
	data := bytes.Repeat([]byte{0xBB}, 3000)
	statSN := uint32(1)

	n, err := dw.WriteDataIn(w, data, 0x200, 1, 10, &statSN)
	if err != nil {
		t.Fatal(err)
	}
	// 3000 / 1024 = 3 PDUs (1024 + 1024 + 952)
	if n != 3 {
		t.Fatalf("expected 3 PDUs, got %d", n)
	}

	// Read them back
	var reassembled []byte
	for i := 0; i < 3; i++ {
		pdu, err := ReadPDU(w)
		if err != nil {
			t.Fatalf("PDU %d: %v", i, err)
		}
		if pdu.DataSN() != uint32(i) {
			t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN())
		}
		reassembled = append(reassembled, pdu.DataSegment...)

		if i < 2 {
			// Intermediate PDUs must not carry the Final bit.
			if pdu.OpSpecific1()&FlagF != 0 {
				t.Fatalf("PDU %d should not have F-bit", i)
			}
		} else {
			// Last PDU carries both Final and Status bits.
			if pdu.OpSpecific1()&FlagF == 0 {
				t.Fatal("last PDU should have F-bit")
			}
			if pdu.OpSpecific1()&FlagS == 0 {
				t.Fatal("last PDU should have S-bit")
			}
		}
	}

	if !bytes.Equal(reassembled, data) {
		t.Fatal("reassembled data mismatch")
	}
}
||||
|
|
||||
|
func testDataInExactBoundary(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(1024) |
||||
|
data := bytes.Repeat([]byte{0xCC}, 2048) // exact 2 PDUs
|
||||
|
statSN := uint32(1) |
||||
|
|
||||
|
n, err := dw.WriteDataIn(w, data, 0x300, 1, 10, &statSN) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if n != 2 { |
||||
|
t.Fatalf("expected 2 PDUs, got %d", n) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataInZeroLength(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(8192) |
||||
|
statSN := uint32(5) |
||||
|
|
||||
|
n, err := dw.WriteDataIn(w, nil, 0x400, 1, 10, &statSN) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if n != 1 { |
||||
|
t.Fatalf("expected 1 PDU for zero-length, got %d", n) |
||||
|
} |
||||
|
if statSN != 6 { |
||||
|
t.Fatal("StatSN should still increment") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataInDataSNOrdering(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(512) |
||||
|
data := bytes.Repeat([]byte{0xDD}, 2048) // 4 PDUs
|
||||
|
statSN := uint32(1) |
||||
|
|
||||
|
dw.WriteDataIn(w, data, 0x500, 1, 10, &statSN) |
||||
|
|
||||
|
for i := 0; i < 4; i++ { |
||||
|
pdu, _ := ReadPDU(w) |
||||
|
if pdu.DataSN() != uint32(i) { |
||||
|
t.Fatalf("PDU %d: DataSN=%d", i, pdu.DataSN()) |
||||
|
} |
||||
|
expectedOffset := uint32(i) * 512 |
||||
|
if pdu.BufferOffset() != expectedOffset { |
||||
|
t.Fatalf("PDU %d: offset=%d, expected %d", i, pdu.BufferOffset(), expectedOffset) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataInFbitSbit(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(1000) |
||||
|
data := bytes.Repeat([]byte{0xEE}, 2500) |
||||
|
statSN := uint32(1) |
||||
|
|
||||
|
dw.WriteDataIn(w, data, 0x600, 1, 10, &statSN) |
||||
|
|
||||
|
for i := 0; i < 3; i++ { |
||||
|
pdu, _ := ReadPDU(w) |
||||
|
flags := pdu.OpSpecific1() |
||||
|
if i < 2 { |
||||
|
if flags&FlagF != 0 || flags&FlagS != 0 { |
||||
|
t.Fatalf("PDU %d: should have no F/S bits", i) |
||||
|
} |
||||
|
} else { |
||||
|
if flags&FlagF == 0 || flags&FlagS == 0 { |
||||
|
t.Fatal("last PDU must have F+S bits") |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataOutSinglePDU(t *testing.T) { |
||||
|
c := NewDataOutCollector(4096) |
||||
|
|
||||
|
pdu := &PDU{} |
||||
|
pdu.SetOpcode(OpSCSIDataOut) |
||||
|
pdu.SetOpSpecific1(FlagF) |
||||
|
pdu.SetDataSN(0) |
||||
|
pdu.SetBufferOffset(0) |
||||
|
pdu.DataSegment = bytes.Repeat([]byte{0x11}, 4096) |
||||
|
|
||||
|
if err := c.AddDataOut(pdu); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if !c.Done() { |
||||
|
t.Fatal("should be done") |
||||
|
} |
||||
|
if c.Remaining() != 0 { |
||||
|
t.Fatal("remaining should be 0") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testDataOutMultiPDU verifies two sequential Data-Out PDUs (DataSN 0 and 1
// at offsets 0 and 4096) assemble into one contiguous 8KB buffer.
func testDataOutMultiPDU(t *testing.T) {
	c := NewDataOutCollector(8192)

	for i := 0; i < 2; i++ {
		pdu := &PDU{}
		pdu.SetOpcode(OpSCSIDataOut)
		pdu.SetDataSN(uint32(i))
		pdu.SetBufferOffset(uint32(i) * 4096)
		pdu.DataSegment = bytes.Repeat([]byte{byte(i + 1)}, 4096)
		if i == 1 {
			// F-bit on the last Data-Out marks the sequence complete.
			pdu.SetOpSpecific1(FlagF)
		}
		if err := c.AddDataOut(pdu); err != nil {
			t.Fatalf("PDU %d: %v", i, err)
		}
	}

	if !c.Done() {
		t.Fatal("should be done")
	}
	data := c.Data()
	// First block was filled with 0x01, second with 0x02.
	if data[0] != 0x01 || data[4096] != 0x02 {
		t.Fatal("data assembly wrong")
	}
}
||||
|
|
||||
|
func testDataOutImmediateData(t *testing.T) { |
||||
|
c := NewDataOutCollector(4096) |
||||
|
err := c.AddImmediateData(bytes.Repeat([]byte{0xFF}, 4096)) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if !c.Done() { |
||||
|
t.Fatal("should be done with immediate data") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testDataOutImmediatePlusR2T verifies the mixed path: the first half of an
// 8KB transfer arrives as immediate data, the second half as an
// R2T-solicited Data-Out PDU at offset 4096 (its own DataSN sequence
// restarts at 0).
func testDataOutImmediatePlusR2T(t *testing.T) {
	c := NewDataOutCollector(8192)

	// Immediate: first 4096
	err := c.AddImmediateData(bytes.Repeat([]byte{0xAA}, 4096))
	if err != nil {
		t.Fatal(err)
	}
	if c.Done() {
		t.Fatal("should not be done yet")
	}
	if c.Remaining() != 4096 {
		t.Fatalf("remaining: %d", c.Remaining())
	}

	// R2T-solicited Data-Out: next 4096
	pdu := &PDU{}
	pdu.SetOpcode(OpSCSIDataOut)
	pdu.SetOpSpecific1(FlagF)
	pdu.SetDataSN(0)
	pdu.SetBufferOffset(4096)
	pdu.DataSegment = bytes.Repeat([]byte{0xBB}, 4096)
	if err := c.AddDataOut(pdu); err != nil {
		t.Fatal(err)
	}
	if !c.Done() {
		t.Fatal("should be done")
	}

	// Each half must retain its distinct fill byte.
	data := c.Data()
	if data[0] != 0xAA || data[4096] != 0xBB {
		t.Fatal("assembly wrong")
	}
}
||||
|
|
||||
|
func testDataOutWrongDataSN(t *testing.T) { |
||||
|
c := NewDataOutCollector(8192) |
||||
|
|
||||
|
pdu := &PDU{} |
||||
|
pdu.SetOpcode(OpSCSIDataOut) |
||||
|
pdu.SetDataSN(1) // should be 0
|
||||
|
pdu.SetBufferOffset(0) |
||||
|
pdu.DataSegment = make([]byte, 4096) |
||||
|
|
||||
|
err := c.AddDataOut(pdu) |
||||
|
if err != ErrDataSNOrder { |
||||
|
t.Fatalf("expected ErrDataSNOrder, got %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDataOutOverflow(t *testing.T) { |
||||
|
c := NewDataOutCollector(4096) |
||||
|
|
||||
|
pdu := &PDU{} |
||||
|
pdu.SetOpcode(OpSCSIDataOut) |
||||
|
pdu.SetDataSN(0) |
||||
|
pdu.SetBufferOffset(0) |
||||
|
pdu.DataSegment = make([]byte, 8192) // more than expected
|
||||
|
|
||||
|
err := c.AddDataOut(pdu) |
||||
|
if err != ErrDataOverflow { |
||||
|
t.Fatalf("expected ErrDataOverflow, got %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testR2TBuild verifies every field BuildR2T stamps into the R2T PDU:
// opcode, task tags, R2TSN, the requested data window, and StatSN.
func testR2TBuild(t *testing.T) {
	pdu := BuildR2T(0x100, 0x200, 0, 4096, 4096, 5, 10, 20)
	if pdu.Opcode() != OpR2T {
		t.Fatal("wrong opcode")
	}
	if pdu.InitiatorTaskTag() != 0x100 {
		t.Fatal("ITT wrong")
	}
	if pdu.TargetTransferTag() != 0x200 {
		t.Fatal("TTT wrong")
	}
	if pdu.R2TSN() != 0 {
		t.Fatal("R2TSN wrong")
	}
	if pdu.BufferOffset() != 4096 {
		t.Fatal("offset wrong")
	}
	if pdu.DesiredDataLength() != 4096 {
		t.Fatal("desired length wrong")
	}
	if pdu.StatSN() != 5 {
		t.Fatal("StatSN wrong")
	}
}
||||
|
|
||||
|
func testSCSIResponseGood(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
statSN := uint32(10) |
||||
|
result := SCSIResult{Status: SCSIStatusGood} |
||||
|
err := SendSCSIResponse(w, result, 0x300, &statSN, 5, 15) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if statSN != 11 { |
||||
|
t.Fatal("StatSN not incremented") |
||||
|
} |
||||
|
|
||||
|
pdu, err := ReadPDU(w) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if pdu.Opcode() != OpSCSIResp { |
||||
|
t.Fatal("wrong opcode") |
||||
|
} |
||||
|
if pdu.SCSIStatus() != SCSIStatusGood { |
||||
|
t.Fatal("status wrong") |
||||
|
} |
||||
|
if len(pdu.DataSegment) != 0 { |
||||
|
t.Fatal("no data expected for good status") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testSCSIResponseCheckCondition verifies a CHECK CONDITION response
// carries fixed-format sense data (18 bytes) behind the 2-byte big-endian
// length prefix in the data segment.
func testSCSIResponseCheckCondition(t *testing.T) {
	w := &bytes.Buffer{}
	statSN := uint32(20)
	result := SCSIResult{
		Status:    SCSIStatusCheckCond,
		SenseKey:  SenseIllegalRequest,
		SenseASC:  ASCInvalidOpcode,
		SenseASCQ: ASCQLuk,
	}
	err := SendSCSIResponse(w, result, 0x400, &statSN, 5, 15)
	if err != nil {
		t.Fatal(err)
	}

	pdu, err := ReadPDU(w)
	if err != nil {
		t.Fatal(err)
	}
	if pdu.SCSIStatus() != SCSIStatusCheckCond {
		t.Fatal("status wrong")
	}
	// Data segment should contain sense data with 2-byte length prefix
	if len(pdu.DataSegment) < 20 { // 2 + 18
		t.Fatalf("sense data too short: %d", len(pdu.DataSegment))
	}
	// Big-endian 16-bit length prefix must equal the fixed sense size.
	senseLen := int(pdu.DataSegment[0])<<8 | int(pdu.DataSegment[1])
	if senseLen != 18 {
		t.Fatalf("sense length: %d", senseLen)
	}
}
||||
|
|
||||
|
func testDataInStatSNIncrement(t *testing.T) { |
||||
|
w := &bytes.Buffer{} |
||||
|
dw := NewDataInWriter(1024) |
||||
|
data := bytes.Repeat([]byte{0x00}, 3072) // 3 PDUs
|
||||
|
statSN := uint32(100) |
||||
|
|
||||
|
dw.WriteDataIn(w, data, 0x700, 1, 10, &statSN) |
||||
|
// Only the final PDU has S-bit, so StatSN increments once
|
||||
|
if statSN != 101 { |
||||
|
t.Fatalf("StatSN should be 101, got %d", statSN) |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,68 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
// HandleTextRequest processes a Text Request PDU.
|
||||
|
// Currently supports SendTargets discovery (RFC 7143, Section 12.3).
|
||||
|
func HandleTextRequest(req *PDU, targets []DiscoveryTarget) *PDU { |
||||
|
resp := &PDU{} |
||||
|
resp.SetOpcode(OpTextResp) |
||||
|
resp.SetOpSpecific1(FlagF) // Final
|
||||
|
resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) |
||||
|
resp.SetTargetTransferTag(0xFFFFFFFF) // no continuation
|
||||
|
|
||||
|
params, err := ParseParams(req.DataSegment) |
||||
|
if err != nil { |
||||
|
// Malformed — return empty response
|
||||
|
return resp |
||||
|
} |
||||
|
|
||||
|
val, ok := params.Get("SendTargets") |
||||
|
if !ok { |
||||
|
return resp |
||||
|
} |
||||
|
|
||||
|
// Discovery responses can have duplicate keys (TargetName appears
|
||||
|
// for each target), so we use EncodeDiscoveryTargets directly.
|
||||
|
var matched []DiscoveryTarget |
||||
|
|
||||
|
switch val { |
||||
|
case "All": |
||||
|
matched = targets |
||||
|
default: |
||||
|
for _, tgt := range targets { |
||||
|
if tgt.Name == val { |
||||
|
matched = append(matched, tgt) |
||||
|
break |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if data := EncodeDiscoveryTargets(matched); len(data) > 0 { |
||||
|
resp.DataSegment = data |
||||
|
} |
||||
|
return resp |
||||
|
} |
||||
|
|
||||
|
// DiscoveryTarget represents a target available for discovery via
// SendTargets.
type DiscoveryTarget struct {
	Name    string // IQN, e.g., "iqn.2024.com.seaweedfs:vol1"
	Address string // IP:port,portal-group-tag, e.g., "10.0.0.1:3260,1"; may be empty
}
||||
|
|
||||
|
// EncodeDiscoveryTargets encodes multiple targets into text parameter format.
|
||||
|
// Each target produces TargetName=<iqn>\0TargetAddress=<addr>\0
|
||||
|
// This handles the special multi-value encoding where the same key can appear
|
||||
|
// multiple times (unlike normal negotiation params).
|
||||
|
func EncodeDiscoveryTargets(targets []DiscoveryTarget) []byte { |
||||
|
if len(targets) == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
var buf []byte |
||||
|
for _, tgt := range targets { |
||||
|
buf = append(buf, "TargetName="+tgt.Name+"\x00"...) |
||||
|
if tgt.Address != "" { |
||||
|
buf = append(buf, "TargetAddress="+tgt.Address+"\x00"...) |
||||
|
} |
||||
|
} |
||||
|
return buf |
||||
|
} |
||||
@ -0,0 +1,213 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"strings" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestDiscovery(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{"send_targets_all", testSendTargetsAll}, |
||||
|
{"send_targets_specific", testSendTargetsSpecific}, |
||||
|
{"send_targets_not_found", testSendTargetsNotFound}, |
||||
|
{"send_targets_empty_list", testSendTargetsEmptyList}, |
||||
|
{"no_send_targets_key", testNoSendTargetsKey}, |
||||
|
{"malformed_text_request", testMalformedTextRequest}, |
||||
|
{"special_chars_in_iqn", testSpecialCharsInIQN}, |
||||
|
{"multiple_targets", testMultipleTargets}, |
||||
|
{"encode_discovery_targets", testEncodeDiscoveryTargets}, |
||||
|
{"encode_empty_targets", testEncodeEmptyTargets}, |
||||
|
{"target_without_address", testTargetWithoutAddress}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testTargets is the shared two-target fixture used by the discovery tests.
var testTargets = []DiscoveryTarget{
	{Name: "iqn.2024.com.seaweedfs:vol1", Address: "10.0.0.1:3260,1"},
	{Name: "iqn.2024.com.seaweedfs:vol2", Address: "10.0.0.2:3260,1"},
}
||||
|
|
||||
|
func makeTextReq(params *Params) *PDU { |
||||
|
p := &PDU{} |
||||
|
p.SetOpcode(OpTextReq) |
||||
|
p.SetOpSpecific1(FlagF) |
||||
|
p.SetInitiatorTaskTag(0x1234) |
||||
|
if params != nil { |
||||
|
p.DataSegment = params.Encode() |
||||
|
} |
||||
|
return p |
||||
|
} |
||||
|
|
||||
|
// testSendTargetsAll verifies SendTargets=All echoes the ITT and lists
// every registered target name and at least one portal address.
func testSendTargetsAll(t *testing.T) {
	params := NewParams()
	params.Set("SendTargets", "All")
	req := makeTextReq(params)

	resp := HandleTextRequest(req, testTargets)

	if resp.Opcode() != OpTextResp {
		t.Fatalf("opcode: 0x%02x", resp.Opcode())
	}
	if resp.InitiatorTaskTag() != 0x1234 {
		t.Fatal("ITT mismatch")
	}

	body := string(resp.DataSegment)
	if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol1") {
		t.Fatal("missing vol1")
	}
	if !strings.Contains(body, "iqn.2024.com.seaweedfs:vol2") {
		t.Fatal("missing vol2")
	}
	if !strings.Contains(body, "10.0.0.1:3260,1") {
		t.Fatal("missing vol1 address")
	}
}
||||
|
|
||||
|
func testSendTargetsSpecific(t *testing.T) { |
||||
|
params := NewParams() |
||||
|
params.Set("SendTargets", "iqn.2024.com.seaweedfs:vol2") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, testTargets) |
||||
|
body := string(resp.DataSegment) |
||||
|
if !strings.Contains(body, "vol2") { |
||||
|
t.Fatal("missing vol2") |
||||
|
} |
||||
|
// Should not contain vol1
|
||||
|
if strings.Contains(body, "vol1") { |
||||
|
t.Fatal("should not contain vol1") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSendTargetsNotFound(t *testing.T) { |
||||
|
params := NewParams() |
||||
|
params.Set("SendTargets", "iqn.2024.com.seaweedfs:nonexistent") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, testTargets) |
||||
|
if len(resp.DataSegment) != 0 { |
||||
|
t.Fatalf("expected empty response, got %q", resp.DataSegment) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSendTargetsEmptyList(t *testing.T) { |
||||
|
params := NewParams() |
||||
|
params.Set("SendTargets", "All") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, nil) |
||||
|
if len(resp.DataSegment) != 0 { |
||||
|
t.Fatalf("expected empty response, got %q", resp.DataSegment) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testNoSendTargetsKey(t *testing.T) { |
||||
|
params := NewParams() |
||||
|
params.Set("SomethingElse", "value") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, testTargets) |
||||
|
if len(resp.DataSegment) != 0 { |
||||
|
t.Fatal("expected empty response for non-SendTargets request") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testMalformedTextRequest(t *testing.T) { |
||||
|
req := &PDU{} |
||||
|
req.SetOpcode(OpTextReq) |
||||
|
req.SetInitiatorTaskTag(0x5678) |
||||
|
req.DataSegment = []byte("not a valid param format") // no '='
|
||||
|
|
||||
|
resp := HandleTextRequest(req, testTargets) |
||||
|
// Should return empty response without error
|
||||
|
if resp.Opcode() != OpTextResp { |
||||
|
t.Fatalf("opcode: 0x%02x", resp.Opcode()) |
||||
|
} |
||||
|
if resp.InitiatorTaskTag() != 0x5678 { |
||||
|
t.Fatal("ITT mismatch") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSpecialCharsInIQN(t *testing.T) { |
||||
|
targets := []DiscoveryTarget{ |
||||
|
{Name: "iqn.2024-01.com.example:storage.tape1.sys1.xyz", Address: "192.168.1.100:3260,1"}, |
||||
|
} |
||||
|
params := NewParams() |
||||
|
params.Set("SendTargets", "All") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, targets) |
||||
|
body := string(resp.DataSegment) |
||||
|
if !strings.Contains(body, "iqn.2024-01.com.example:storage.tape1.sys1.xyz") { |
||||
|
t.Fatal("IQN with special chars not found") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testMultipleTargets(t *testing.T) { |
||||
|
targets := make([]DiscoveryTarget, 10) |
||||
|
for i := range targets { |
||||
|
targets[i] = DiscoveryTarget{ |
||||
|
Name: "iqn.2024.com.test:vol" + string(rune('0'+i)), |
||||
|
Address: "10.0.0.1:3260,1", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
params := NewParams() |
||||
|
params.Set("SendTargets", "All") |
||||
|
req := makeTextReq(params) |
||||
|
|
||||
|
resp := HandleTextRequest(req, targets) |
||||
|
body := string(resp.DataSegment) |
||||
|
// The response uses EncodeDiscoveryTargets internally via the params,
|
||||
|
// but since Params doesn't allow duplicate keys, the last one wins.
|
||||
|
// This is a known limitation — for multi-target discovery, we use
|
||||
|
// EncodeDiscoveryTargets directly. Let's verify at least the last target.
|
||||
|
if !strings.Contains(body, "TargetName=") { |
||||
|
t.Fatal("no TargetName in response") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testEncodeDiscoveryTargets(t *testing.T) { |
||||
|
encoded := EncodeDiscoveryTargets(testTargets) |
||||
|
body := string(encoded) |
||||
|
|
||||
|
// Should contain both targets with proper format
|
||||
|
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { |
||||
|
t.Fatal("missing vol1") |
||||
|
} |
||||
|
if !strings.Contains(body, "TargetAddress=10.0.0.1:3260,1\x00") { |
||||
|
t.Fatal("missing vol1 address") |
||||
|
} |
||||
|
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol2\x00") { |
||||
|
t.Fatal("missing vol2") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testEncodeEmptyTargets(t *testing.T) { |
||||
|
encoded := EncodeDiscoveryTargets(nil) |
||||
|
if encoded != nil { |
||||
|
t.Fatalf("expected nil, got %q", encoded) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testTargetWithoutAddress(t *testing.T) { |
||||
|
targets := []DiscoveryTarget{ |
||||
|
{Name: "iqn.2024.com.seaweedfs:vol1"}, // no address
|
||||
|
} |
||||
|
encoded := EncodeDiscoveryTargets(targets) |
||||
|
body := string(encoded) |
||||
|
if !strings.Contains(body, "TargetName=iqn.2024.com.seaweedfs:vol1\x00") { |
||||
|
t.Fatal("missing target name") |
||||
|
} |
||||
|
if strings.Contains(body, "TargetAddress") { |
||||
|
t.Fatal("should not have TargetAddress") |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,412 @@ |
|||||
|
package iscsi_test |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"io" |
||||
|
"log" |
||||
|
"net" |
||||
|
"path/filepath" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol" |
||||
|
"github.com/seaweedfs/seaweedfs/weed/storage/blockvol/iscsi" |
||||
|
) |
||||
|
|
||||
|
// blockVolAdapter wraps a BlockVol to implement iscsi.BlockDevice.
type blockVolAdapter struct {
	vol *blockvol.BlockVol // backing extent-mapped block volume
}

// ReadAt reads length bytes starting at the given LBA.
func (a *blockVolAdapter) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return a.vol.ReadLBA(lba, length)
}

// WriteAt writes data starting at the given LBA.
func (a *blockVolAdapter) WriteAt(lba uint64, data []byte) error {
	return a.vol.WriteLBA(lba, data)
}

// Trim delegates an UNMAP/trim of length bytes at lba to the volume.
func (a *blockVolAdapter) Trim(lba uint64, length uint32) error {
	return a.vol.Trim(lba, length)
}

// SyncCache delegates a cache flush to the volume.
func (a *blockVolAdapter) SyncCache() error {
	return a.vol.SyncCache()
}

// BlockSize reports the volume's logical block size in bytes.
func (a *blockVolAdapter) BlockSize() uint32 {
	return a.vol.Info().BlockSize
}

// VolumeSize reports the volume's total size in bytes.
func (a *blockVolAdapter) VolumeSize() uint64 {
	return a.vol.Info().VolumeSize
}

// IsHealthy reports the volume's health flag.
func (a *blockVolAdapter) IsHealthy() bool {
	return a.vol.Info().Healthy
}
||||
|
|
||||
|
const (
	// intTargetName is the IQN the integration-test target serves.
	intTargetName = "iqn.2024.com.seaweedfs:integration"
	// intInitiatorName is the IQN the test client logs in with.
	intInitiatorName = "iqn.2024.com.test:client"
)
||||
|
|
||||
|
func createTestVol(t *testing.T) *blockvol.BlockVol { |
||||
|
t.Helper() |
||||
|
path := filepath.Join(t.TempDir(), "test.blk") |
||||
|
vol, err := blockvol.CreateBlockVol(path, blockvol.CreateOptions{ |
||||
|
VolumeSize: 1024 * 4096, // 1024 blocks = 4MB
|
||||
|
BlockSize: 4096, |
||||
|
WALSize: 1024 * 1024, // 1MB WAL
|
||||
|
}) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
t.Cleanup(func() { vol.Close() }) |
||||
|
return vol |
||||
|
} |
||||
|
|
||||
|
func setupIntegrationTarget(t *testing.T) (net.Conn, *iscsi.TargetServer) { |
||||
|
t.Helper() |
||||
|
vol := createTestVol(t) |
||||
|
adapter := &blockVolAdapter{vol: vol} |
||||
|
|
||||
|
config := iscsi.DefaultTargetConfig() |
||||
|
config.TargetName = intTargetName |
||||
|
logger := log.New(io.Discard, "", 0) |
||||
|
ts := iscsi.NewTargetServer("127.0.0.1:0", config, logger) |
||||
|
ts.AddVolume(intTargetName, adapter) |
||||
|
|
||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
go ts.Serve(ln) |
||||
|
t.Cleanup(func() { ts.Close() }) |
||||
|
|
||||
|
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
t.Cleanup(func() { conn.Close() }) |
||||
|
|
||||
|
// Login
|
||||
|
params := iscsi.NewParams() |
||||
|
params.Set("InitiatorName", intInitiatorName) |
||||
|
params.Set("TargetName", intTargetName) |
||||
|
params.Set("SessionType", "Normal") |
||||
|
|
||||
|
loginReq := &iscsi.PDU{} |
||||
|
loginReq.SetOpcode(iscsi.OpLoginReq) |
||||
|
loginReq.SetLoginStages(iscsi.StageSecurityNeg, iscsi.StageFullFeature) |
||||
|
loginReq.SetLoginTransit(true) |
||||
|
loginReq.SetISID([6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}) |
||||
|
loginReq.SetCmdSN(1) |
||||
|
loginReq.DataSegment = params.Encode() |
||||
|
|
||||
|
if err := iscsi.WritePDU(conn, loginReq); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
resp, err := iscsi.ReadPDU(conn) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if resp.LoginStatusClass() != iscsi.LoginStatusSuccess { |
||||
|
t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) |
||||
|
} |
||||
|
|
||||
|
return conn, ts |
||||
|
} |
||||
|
|
||||
|
func TestIntegration(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{"write_read_verify", testWriteReadVerify}, |
||||
|
{"multi_block_write_read", testMultiBlockWriteRead}, |
||||
|
{"inquiry_capacity", testInquiryCapacity}, |
||||
|
{"sync_cache", testIntSyncCache}, |
||||
|
{"test_unit_ready", testIntTestUnitReady}, |
||||
|
{"concurrent_readers_writers", testConcurrentReadersWriters}, |
||||
|
{"write_at_boundary", testWriteAtBoundary}, |
||||
|
{"unmap_integration", testUnmapIntegration}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// sendSCSICmd sends one SCSI Command PDU with the given CDB, R/W flags,
// and optional immediate write data, then returns the first response PDU
// read from the connection (Data-In, R2T, or SCSI Response, depending on
// the command). The ITT is set equal to cmdSN for easy correlation.
func sendSCSICmd(t *testing.T, conn net.Conn, cdb [16]byte, cmdSN uint32, read bool, write bool, dataOut []byte, expLen uint32) *iscsi.PDU {
	t.Helper()
	cmd := &iscsi.PDU{}
	cmd.SetOpcode(iscsi.OpSCSICmd)
	flags := uint8(iscsi.FlagF)
	if read {
		flags |= iscsi.FlagR
	}
	if write {
		flags |= iscsi.FlagW
	}
	cmd.SetOpSpecific1(flags)
	cmd.SetInitiatorTaskTag(cmdSN)
	cmd.SetExpectedDataTransferLength(expLen)
	cmd.SetCmdSN(cmdSN)
	cmd.SetCDB(cdb)
	if dataOut != nil {
		// Write payload travels as immediate data in the command PDU.
		cmd.DataSegment = dataOut
	}

	if err := iscsi.WritePDU(conn, cmd); err != nil {
		t.Fatal(err)
	}

	resp, err := iscsi.ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	return resp
}
||||
|
|
||||
|
// testWriteReadVerify writes one 4KB block of patterned data to LBA 0 via
// WRITE(10), reads it back via READ(10), and verifies byte-for-byte
// integrity of the Data-In payload.
func testWriteReadVerify(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Write pattern to LBA 0
	writeData := make([]byte, 4096)
	for i := range writeData {
		writeData[i] = byte(i % 251) // prime modulus for pattern
	}

	// WRITE(10) CDB: 32-bit LBA at bytes 2-5, 16-bit block count at 7-8.
	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 0) // LBA 0
	binary.BigEndian.PutUint16(wCDB[7:9], 1) // 1 block

	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096)
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}

	// Read it back
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 0)
	binary.BigEndian.PutUint16(rCDB[7:9], 1)

	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, 4096)
	if resp2.Opcode() != iscsi.OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", iscsi.OpcodeName(resp2.Opcode()))
	}

	if !bytes.Equal(resp2.DataSegment, writeData) {
		t.Fatal("data integrity verification failed")
	}
}
||||
|
|
||||
|
// testMultiBlockWriteRead writes four distinctly-filled 4KB blocks at LBA
// 10, reads them back — reassembling across however many Data-In PDUs the
// target splits the payload into (the S-bit marks the last) — and verifies
// integrity.
func testMultiBlockWriteRead(t *testing.T) {
	conn, _ := setupIntegrationTarget(t)

	// Write 4 blocks at LBA 10
	blockCount := uint16(4)
	dataLen := int(blockCount) * 4096
	writeData := make([]byte, dataLen)
	for i := range writeData {
		writeData[i] = byte((i / 4096) + 1) // each block has different fill
	}

	var wCDB [16]byte
	wCDB[0] = iscsi.ScsiWrite10
	binary.BigEndian.PutUint32(wCDB[2:6], 10)
	binary.BigEndian.PutUint16(wCDB[7:9], blockCount)

	resp := sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, uint32(dataLen))
	if resp.SCSIStatus() != iscsi.SCSIStatusGood {
		t.Fatalf("write failed: status %d", resp.SCSIStatus())
	}

	// Read back all 4 blocks
	var rCDB [16]byte
	rCDB[0] = iscsi.ScsiRead10
	binary.BigEndian.PutUint32(rCDB[2:6], 10)
	binary.BigEndian.PutUint16(rCDB[7:9], blockCount)

	resp2 := sendSCSICmd(t, conn, rCDB, 3, true, false, nil, uint32(dataLen))

	// May be split across multiple Data-In PDUs — reassemble
	var readData []byte
	readData = append(readData, resp2.DataSegment...)

	// If first PDU doesn't have S-bit, read more
	for resp2.OpSpecific1()&iscsi.FlagS == 0 {
		var err error
		resp2, err = iscsi.ReadPDU(conn)
		if err != nil {
			t.Fatal(err)
		}
		readData = append(readData, resp2.DataSegment...)
	}

	if !bytes.Equal(readData, writeData) {
		t.Fatal("multi-block data integrity failed")
	}
}
||||
|
|
||||
|
func testInquiryCapacity(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
// TEST UNIT READY
|
||||
|
var tuCDB [16]byte |
||||
|
tuCDB[0] = iscsi.ScsiTestUnitReady |
||||
|
resp := sendSCSICmd(t, conn, tuCDB, 2, false, false, nil, 0) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatal("test unit ready failed") |
||||
|
} |
||||
|
|
||||
|
// INQUIRY
|
||||
|
var iqCDB [16]byte |
||||
|
iqCDB[0] = iscsi.ScsiInquiry |
||||
|
binary.BigEndian.PutUint16(iqCDB[3:5], 96) |
||||
|
resp = sendSCSICmd(t, conn, iqCDB, 3, true, false, nil, 96) |
||||
|
if resp.Opcode() != iscsi.OpSCSIDataIn { |
||||
|
t.Fatal("expected Data-In for inquiry") |
||||
|
} |
||||
|
if resp.DataSegment[0] != 0x00 { // SBC device
|
||||
|
t.Fatal("not SBC device type") |
||||
|
} |
||||
|
|
||||
|
// READ CAPACITY 10
|
||||
|
var rcCDB [16]byte |
||||
|
rcCDB[0] = iscsi.ScsiReadCapacity10 |
||||
|
resp = sendSCSICmd(t, conn, rcCDB, 4, true, false, nil, 8) |
||||
|
if resp.Opcode() != iscsi.OpSCSIDataIn { |
||||
|
t.Fatal("expected Data-In for read capacity") |
||||
|
} |
||||
|
lastLBA := binary.BigEndian.Uint32(resp.DataSegment[0:4]) |
||||
|
blockSize := binary.BigEndian.Uint32(resp.DataSegment[4:8]) |
||||
|
if lastLBA != 1023 { // 1024 blocks, last = 1023
|
||||
|
t.Fatalf("last LBA: %d, expected 1023", lastLBA) |
||||
|
} |
||||
|
if blockSize != 4096 { |
||||
|
t.Fatalf("block size: %d", blockSize) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testIntSyncCache(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = iscsi.ScsiSyncCache10 |
||||
|
resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatalf("sync cache failed: %d", resp.SCSIStatus()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testIntTestUnitReady(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = iscsi.ScsiTestUnitReady |
||||
|
resp := sendSCSICmd(t, conn, cdb, 2, false, false, nil, 0) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatal("TUR failed") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testConcurrentReadersWriters(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
// Sequential writes to different LBAs, then sequential reads to verify
|
||||
|
cmdSN := uint32(2) |
||||
|
for lba := uint32(0); lba < 10; lba++ { |
||||
|
data := bytes.Repeat([]byte{byte(lba)}, 4096) |
||||
|
var wCDB [16]byte |
||||
|
wCDB[0] = iscsi.ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(wCDB[2:6], lba) |
||||
|
binary.BigEndian.PutUint16(wCDB[7:9], 1) |
||||
|
resp := sendSCSICmd(t, conn, wCDB, cmdSN, false, true, data, 4096) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatalf("write LBA %d failed", lba) |
||||
|
} |
||||
|
cmdSN++ |
||||
|
} |
||||
|
|
||||
|
// Read back and verify
|
||||
|
for lba := uint32(0); lba < 10; lba++ { |
||||
|
var rCDB [16]byte |
||||
|
rCDB[0] = iscsi.ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(rCDB[2:6], lba) |
||||
|
binary.BigEndian.PutUint16(rCDB[7:9], 1) |
||||
|
resp := sendSCSICmd(t, conn, rCDB, cmdSN, true, false, nil, 4096) |
||||
|
if resp.Opcode() != iscsi.OpSCSIDataIn { |
||||
|
t.Fatalf("LBA %d: expected Data-In", lba) |
||||
|
} |
||||
|
expected := bytes.Repeat([]byte{byte(lba)}, 4096) |
||||
|
if !bytes.Equal(resp.DataSegment, expected) { |
||||
|
t.Fatalf("LBA %d: data mismatch", lba) |
||||
|
} |
||||
|
cmdSN++ |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWriteAtBoundary(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
// Write at last valid LBA (1023)
|
||||
|
data := bytes.Repeat([]byte{0xFF}, 4096) |
||||
|
var wCDB [16]byte |
||||
|
wCDB[0] = iscsi.ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(wCDB[2:6], 1023) |
||||
|
binary.BigEndian.PutUint16(wCDB[7:9], 1) |
||||
|
resp := sendSCSICmd(t, conn, wCDB, 2, false, true, data, 4096) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatal("boundary write failed") |
||||
|
} |
||||
|
|
||||
|
// Write past end — should fail
|
||||
|
var wCDB2 [16]byte |
||||
|
wCDB2[0] = iscsi.ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(wCDB2[2:6], 1024) // out of bounds
|
||||
|
binary.BigEndian.PutUint16(wCDB2[7:9], 1) |
||||
|
resp2 := sendSCSICmd(t, conn, wCDB2, 3, false, true, data, 4096) |
||||
|
if resp2.SCSIStatus() == iscsi.SCSIStatusGood { |
||||
|
t.Fatal("OOB write should fail") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testUnmapIntegration(t *testing.T) { |
||||
|
conn, _ := setupIntegrationTarget(t) |
||||
|
|
||||
|
// Write data at LBA 5
|
||||
|
writeData := bytes.Repeat([]byte{0xCC}, 4096) |
||||
|
var wCDB [16]byte |
||||
|
wCDB[0] = iscsi.ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(wCDB[2:6], 5) |
||||
|
binary.BigEndian.PutUint16(wCDB[7:9], 1) |
||||
|
sendSCSICmd(t, conn, wCDB, 2, false, true, writeData, 4096) |
||||
|
|
||||
|
// UNMAP LBA 5
|
||||
|
unmapData := make([]byte, 24) |
||||
|
binary.BigEndian.PutUint16(unmapData[0:2], 22) |
||||
|
binary.BigEndian.PutUint16(unmapData[2:4], 16) |
||||
|
binary.BigEndian.PutUint64(unmapData[8:16], 5) |
||||
|
binary.BigEndian.PutUint32(unmapData[16:20], 1) |
||||
|
|
||||
|
var uCDB [16]byte |
||||
|
uCDB[0] = iscsi.ScsiUnmap |
||||
|
resp := sendSCSICmd(t, conn, uCDB, 3, false, true, unmapData, uint32(len(unmapData))) |
||||
|
if resp.SCSIStatus() != iscsi.SCSIStatusGood { |
||||
|
t.Fatalf("unmap failed: %d", resp.SCSIStatus()) |
||||
|
} |
||||
|
|
||||
|
// Read back — should be zeros
|
||||
|
var rCDB [16]byte |
||||
|
rCDB[0] = iscsi.ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(rCDB[2:6], 5) |
||||
|
binary.BigEndian.PutUint16(rCDB[7:9], 1) |
||||
|
resp2 := sendSCSICmd(t, conn, rCDB, 4, true, false, nil, 4096) |
||||
|
if resp2.Opcode() != iscsi.OpSCSIDataIn { |
||||
|
t.Fatal("expected Data-In") |
||||
|
} |
||||
|
zeros := make([]byte, 4096) |
||||
|
if !bytes.Equal(resp2.DataSegment, zeros) { |
||||
|
t.Fatal("unmapped block should return zeros") |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,396 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"strconv" |
||||
|
) |
||||
|
|
||||
|
// Login status classes (RFC 7143, Section 11.13.5). The class selects the
// broad outcome of a login request; the detail constants below refine it.
const (
	LoginStatusSuccess      uint8 = 0x00 // login proceeding or complete
	LoginStatusRedirect     uint8 = 0x01 // retry at the address in TargetAddress
	LoginStatusInitiatorErr uint8 = 0x02 // initiator-caused failure
	LoginStatusTargetErr    uint8 = 0x03 // target-side failure
)
||||
|
|
||||
|
// Login status details (RFC 7143, Section 11.13.5). A detail code is only
// meaningful together with its status class — the same numeric value recurs
// across classes (e.g. 0x00 means "success", "misc initiator error", or
// "misc target error" depending on the accompanying class).
const (
	LoginDetailSuccess           uint8 = 0x00
	LoginDetailTargetMoved       uint8 = 0x01 // redirect: permanently moved
	LoginDetailTargetMovedTemp   uint8 = 0x02 // redirect: temporarily moved
	LoginDetailInitiatorError    uint8 = 0x00 // initiator error (miscellaneous)
	LoginDetailAuthFailure       uint8 = 0x01 // authentication failure
	LoginDetailAuthorizationFail uint8 = 0x02 // authorization failure
	LoginDetailNotFound          uint8 = 0x03 // target not found
	LoginDetailTargetRemoved     uint8 = 0x04 // target removed
	LoginDetailUnsupported       uint8 = 0x05 // unsupported version
	LoginDetailTooManyConns      uint8 = 0x06 // too many connections
	LoginDetailMissingParam      uint8 = 0x07 // missing parameter
	LoginDetailNoSessionSlot     uint8 = 0x08 // no session slot
	LoginDetailNoTCPConn         uint8 = 0x09 // no TCP connection available
	LoginDetailNoSession         uint8 = 0x0a // no existing session
	LoginDetailTargetError       uint8 = 0x00 // target error (miscellaneous)
	LoginDetailServiceUnavail    uint8 = 0x01 // service unavailable
	LoginDetailOutOfResources    uint8 = 0x02 // out of resources
)
||||
|
|
||||
|
// Sentinel errors for login-processing failures; match with errors.Is.
// Note that HandleLoginPDU itself reports failures in-band via login status
// bytes rather than returning these.
var (
	ErrLoginInvalidStage   = errors.New("iscsi: invalid login stage transition")
	ErrLoginMissingParam   = errors.New("iscsi: missing required login parameter")
	ErrLoginInvalidISID    = errors.New("iscsi: invalid ISID")
	ErrLoginTargetNotFound = errors.New("iscsi: target not found")
	ErrLoginSessionExists  = errors.New("iscsi: session already exists for this ISID")
	ErrLoginInvalidRequest = errors.New("iscsi: invalid login request")
)
||||
|
|
||||
|
// LoginPhase tracks the current phase of login negotiation. The phase
// advances from Start toward Done as the initiator requests stage
// transitions (CSG/NSG with the transit bit) in its login PDUs.
type LoginPhase int

const (
	LoginPhaseStart       LoginPhase = iota // before first PDU
	LoginPhaseSecurity                      // CSG=SecurityNeg
	LoginPhaseOperational                   // CSG=LoginOp
	LoginPhaseDone                          // transition to FFP complete
)
||||
|
|
||||
|
// TargetConfig holds the target-side negotiation defaults. The numeric and
// boolean fields seed the operational-parameter negotiation performed by
// LoginNegotiator.
type TargetConfig struct {
	TargetName  string // IQN this target answers to
	TargetAlias string // optional human-readable alias; declared to the initiator if non-empty

	MaxRecvDataSegmentLength int // largest data segment accepted per PDU, bytes
	MaxBurstLength           int // offered MaxBurstLength, bytes
	FirstBurstLength         int // offered FirstBurstLength, bytes
	MaxConnections           int // value declared for MaxConnections
	MaxOutstandingR2T        int // value declared for MaxOutstandingR2T
	DefaultTime2Wait         int // value declared for DefaultTime2Wait, seconds
	DefaultTime2Retain       int // value declared for DefaultTime2Retain, seconds
	DataPDUInOrder           bool
	DataSequenceInOrder      bool
	InitialR2T               bool // target's side of the InitialR2T negotiation
	ImmediateData            bool // target's side of the ImmediateData negotiation
	ErrorRecoveryLevel       int  // value declared for ErrorRecoveryLevel
}
||||
|
|
||||
|
// DefaultTargetConfig returns sensible defaults for a target.
|
||||
|
func DefaultTargetConfig() TargetConfig { |
||||
|
return TargetConfig{ |
||||
|
MaxRecvDataSegmentLength: 262144, // 256KB
|
||||
|
MaxBurstLength: 262144, |
||||
|
FirstBurstLength: 65536, |
||||
|
MaxConnections: 1, |
||||
|
MaxOutstandingR2T: 1, |
||||
|
DefaultTime2Wait: 2, |
||||
|
DefaultTime2Retain: 0, |
||||
|
DataPDUInOrder: true, |
||||
|
DataSequenceInOrder: true, |
||||
|
InitialR2T: true, |
||||
|
ImmediateData: true, |
||||
|
ErrorRecoveryLevel: 0, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// LoginNegotiator handles the target side of login negotiation.
// It is a per-login state machine: feed each login request PDU to
// HandleLoginPDU, poll Done() after each call, and read the outcome with
// Result() once complete.
type LoginNegotiator struct {
	config   TargetConfig // target defaults that seed negotiation
	phase    LoginPhase   // current position in the login state machine
	isid     [6]byte      // initiator session ID captured during the security stage
	tsih     uint16       // target session handle; assigned on first successful response
	targetOK bool         // target name validated

	// Negotiated values (updated during negotiation)
	NegMaxRecvDataSegLen int
	NegMaxBurstLength    int
	NegFirstBurstLength  int
	NegInitialR2T        bool
	NegImmediateData     bool

	// Initiator/target info captured during login
	InitiatorName string
	TargetName    string
	SessionType   string // "Normal" or "Discovery"
}
||||
|
|
||||
|
// NewLoginNegotiator creates a negotiator for a new login sequence.
|
||||
|
func NewLoginNegotiator(config TargetConfig) *LoginNegotiator { |
||||
|
return &LoginNegotiator{ |
||||
|
config: config, |
||||
|
phase: LoginPhaseStart, |
||||
|
NegMaxRecvDataSegLen: config.MaxRecvDataSegmentLength, |
||||
|
NegMaxBurstLength: config.MaxBurstLength, |
||||
|
NegFirstBurstLength: config.FirstBurstLength, |
||||
|
NegInitialR2T: config.InitialR2T, |
||||
|
NegImmediateData: config.ImmediateData, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// HandleLoginPDU processes one login request PDU and returns the response PDU.
// It manages stage transitions and parameter negotiation.
//
// Failures are reported in-band via the response's login status class/detail
// bytes, never as Go errors. The negotiator mutates its own phase and the
// captured identity fields (InitiatorName, TargetName, SessionType, isid,
// tsih) as side effects of each call.
func (ln *LoginNegotiator) HandleLoginPDU(req *PDU, resolver TargetResolver) *PDU {
	// Pre-fill the response with fields echoed from the request so every
	// early-reject path below returns a well-formed PDU.
	resp := &PDU{}
	resp.SetOpcode(OpLoginResp)
	resp.SetInitiatorTaskTag(req.InitiatorTaskTag())
	resp.SetISID(req.ISID())
	resp.SetTSIH(req.TSIH())

	// Validate opcode
	if req.Opcode() != OpLoginReq {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	csg := req.LoginCSG()
	nsg := req.LoginNSG()
	transit := req.LoginTransit()

	// Parse text parameters from data segment
	params, err := ParseParams(req.DataSegment)
	if err != nil {
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Process based on current stage
	respParams := NewParams()

	switch csg {
	case StageSecurityNeg:
		// Security stage is only valid from Start or while already in it.
		if ln.phase != LoginPhaseStart && ln.phase != LoginPhaseSecurity {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
			return resp
		}
		ln.phase = LoginPhaseSecurity

		// Capture initiator name; required unless already seen on an
		// earlier PDU of this login sequence.
		if name, ok := params.Get("InitiatorName"); ok {
			ln.InitiatorName = name
		} else if ln.InitiatorName == "" {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// Session type defaults to "Normal" if the initiator omits it.
		if st, ok := params.Get("SessionType"); ok {
			ln.SessionType = st
		}
		if ln.SessionType == "" {
			ln.SessionType = "Normal"
		}

		// Target name (required for Normal sessions); validated against the
		// resolver. A nil resolver rejects every Normal-session lookup.
		if tn, ok := params.Get("TargetName"); ok {
			if ln.SessionType == "Normal" {
				if resolver == nil || !resolver.HasTarget(tn) {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailNotFound)
					return resp
				}
			}
			ln.TargetName = tn
			ln.targetOK = true
		} else if ln.SessionType == "Normal" && !ln.targetOK {
			setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
			return resp
		}

		// ISID
		ln.isid = req.ISID()

		// We don't implement CHAP — declare AuthMethod=None
		respParams.Set("AuthMethod", "None")

		// Advance per the requested next stage when the transit bit is set.
		if transit {
			if nsg == StageLoginOp {
				ln.phase = LoginPhaseOperational
			} else if nsg == StageFullFeature {
				ln.phase = LoginPhaseDone
			}
		}

	case StageLoginOp:
		if ln.phase != LoginPhaseOperational && ln.phase != LoginPhaseSecurity {
			// Allow direct jump to LoginOp if security was skipped.
			// NOTE(review): on this path a Normal session's TargetName is
			// never validated against the resolver — confirm whether target
			// validation is expected to happen elsewhere for this case.
			if ln.phase == LoginPhaseStart {
				// Need InitiatorName at minimum
				if name, ok := params.Get("InitiatorName"); ok {
					ln.InitiatorName = name
				} else if ln.InitiatorName == "" {
					setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailMissingParam)
					return resp
				}
			} else {
				setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
				return resp
			}
		}
		ln.phase = LoginPhaseOperational

		// Negotiate operational parameters
		ln.negotiateParams(params, respParams)

		if transit && nsg == StageFullFeature {
			ln.phase = LoginPhaseDone
		}

	default:
		setLoginReject(resp, LoginStatusInitiatorErr, LoginDetailInitiatorError)
		return resp
	}

	// Build response: echo the requested stages and transit bit back.
	resp.SetLoginStages(csg, nsg)
	if transit {
		resp.SetLoginTransit(true)
	}
	resp.SetLoginStatus(LoginStatusSuccess, LoginDetailSuccess)

	// Assign TSIH on first successful login
	if ln.tsih == 0 {
		ln.tsih = 1 // simplified: single session
	}
	resp.SetTSIH(ln.tsih)

	// Encode response params
	if respParams.Len() > 0 {
		resp.DataSegment = respParams.Encode()
	}

	return resp
}
||||
|
|
||||
|
// Done returns true if login negotiation is complete.
|
||||
|
func (ln *LoginNegotiator) Done() bool { |
||||
|
return ln.phase == LoginPhaseDone |
||||
|
} |
||||
|
|
||||
|
// Phase returns the current login phase.
|
||||
|
func (ln *LoginNegotiator) Phase() LoginPhase { |
||||
|
return ln.phase |
||||
|
} |
||||
|
|
||||
|
// negotiateParams processes operational parameter negotiation.
// For each key the initiator offers, it records the negotiated value on ln
// (where the key has negotiated state) and places the target's reply in resp.
// Keys the target does not recognize are answered with "NotUnderstood".
// Keys the initiator does not offer are left at the defaults seeded by
// NewLoginNegotiator and get no reply line.
func (ln *LoginNegotiator) negotiateParams(req *Params, resp *Params) {
	req.Each(func(key, value string) {
		switch key {
		case "MaxRecvDataSegmentLength":
			// NOTE(review): unlike the cases below, the reply echoes the
			// configured value while the negotiated result v is stored —
			// consistent with declarative per-direction semantics for this
			// key, but confirm against NegotiateNumber's contract.
			if v, err := NegotiateNumber(value, ln.config.MaxRecvDataSegmentLength, 512, 16777215); err == nil {
				ln.NegMaxRecvDataSegLen = v
				resp.Set(key, strconv.Itoa(ln.config.MaxRecvDataSegmentLength))
			}
		case "MaxBurstLength":
			if v, err := NegotiateNumber(value, ln.config.MaxBurstLength, 512, 16777215); err == nil {
				ln.NegMaxBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "FirstBurstLength":
			if v, err := NegotiateNumber(value, ln.config.FirstBurstLength, 512, 16777215); err == nil {
				ln.NegFirstBurstLength = v
				resp.Set(key, strconv.Itoa(v))
			}
		case "InitialR2T":
			if v, err := NegotiateBool(value, ln.config.InitialR2T); err == nil {
				// InitialR2T uses OR semantics: result is Yes if either side says Yes
				if value == "Yes" || ln.config.InitialR2T {
					ln.NegInitialR2T = true
				} else {
					ln.NegInitialR2T = v
				}
				resp.Set(key, BoolStr(ln.NegInitialR2T))
			}
		case "ImmediateData":
			if v, err := NegotiateBool(value, ln.config.ImmediateData); err == nil {
				ln.NegImmediateData = v
				resp.Set(key, BoolStr(v))
			}
		// The following keys are answered with the target's configured value;
		// no negotiated state is tracked for them.
		case "MaxConnections":
			resp.Set(key, strconv.Itoa(ln.config.MaxConnections))
		case "DataPDUInOrder":
			resp.Set(key, BoolStr(ln.config.DataPDUInOrder))
		case "DataSequenceInOrder":
			resp.Set(key, BoolStr(ln.config.DataSequenceInOrder))
		case "DefaultTime2Wait":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Wait))
		case "DefaultTime2Retain":
			resp.Set(key, strconv.Itoa(ln.config.DefaultTime2Retain))
		case "MaxOutstandingR2T":
			resp.Set(key, strconv.Itoa(ln.config.MaxOutstandingR2T))
		case "ErrorRecoveryLevel":
			resp.Set(key, strconv.Itoa(ln.config.ErrorRecoveryLevel))
		// Digests are not implemented; always select "None".
		case "HeaderDigest":
			resp.Set(key, "None")
		case "DataDigest":
			resp.Set(key, "None")
		case "TargetName", "InitiatorName", "SessionType", "AuthMethod":
			// Already handled or declarative — skip
		case "TargetAlias":
			// Informational from initiator — skip
		default:
			// Unknown keys: respond with NotUnderstood
			resp.Set(key, "NotUnderstood")
		}
	})

	// Always declare our TargetAlias if configured
	if ln.config.TargetAlias != "" {
		resp.Set("TargetAlias", ln.config.TargetAlias)
	}
}
||||
|
|
||||
|
// TargetResolver allows the login state machine to check if a target exists.
// Passing a nil resolver to HandleLoginPDU causes every Normal-session
// target lookup to be rejected.
type TargetResolver interface {
	// HasTarget reports whether a target with the given name is served here.
	HasTarget(name string) bool
}
||||
|
|
||||
|
func setLoginReject(resp *PDU, class, detail uint8) { |
||||
|
resp.SetLoginStatus(class, detail) |
||||
|
resp.SetLoginTransit(false) |
||||
|
} |
||||
|
|
||||
|
// LoginResult contains the outcome of a completed login negotiation — a
// plain snapshot of the negotiator's captured identity and negotiated
// parameters. Obtain one via LoginNegotiator.Result.
type LoginResult struct {
	InitiatorName string // initiator IQN captured during login
	TargetName    string // target name captured during login
	SessionType   string // "Normal" or "Discovery"
	ISID          [6]byte
	TSIH          uint16

	MaxRecvDataSegLen int
	MaxBurstLength    int
	FirstBurstLength  int
	InitialR2T        bool
	ImmediateData     bool
}
||||
|
|
||||
|
// Result returns the negotiation outcome. Only valid after Done() returns true.
|
||||
|
func (ln *LoginNegotiator) Result() LoginResult { |
||||
|
return LoginResult{ |
||||
|
InitiatorName: ln.InitiatorName, |
||||
|
TargetName: ln.TargetName, |
||||
|
SessionType: ln.SessionType, |
||||
|
ISID: ln.isid, |
||||
|
TSIH: ln.tsih, |
||||
|
MaxRecvDataSegLen: ln.NegMaxRecvDataSegLen, |
||||
|
MaxBurstLength: ln.NegMaxBurstLength, |
||||
|
FirstBurstLength: ln.NegFirstBurstLength, |
||||
|
InitialR2T: ln.NegInitialR2T, |
||||
|
ImmediateData: ln.NegImmediateData, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// BuildRedirectResponse creates a login response that redirects the initiator
|
||||
|
// to a different target address.
|
||||
|
func BuildRedirectResponse(req *PDU, addr string, permanent bool) *PDU { |
||||
|
resp := &PDU{} |
||||
|
resp.SetOpcode(OpLoginResp) |
||||
|
resp.SetInitiatorTaskTag(req.InitiatorTaskTag()) |
||||
|
resp.SetISID(req.ISID()) |
||||
|
|
||||
|
detail := LoginDetailTargetMovedTemp |
||||
|
if permanent { |
||||
|
detail = LoginDetailTargetMoved |
||||
|
} |
||||
|
resp.SetLoginStatus(LoginStatusRedirect, detail) |
||||
|
|
||||
|
params := NewParams() |
||||
|
params.Set("TargetAddress", addr) |
||||
|
resp.DataSegment = params.Encode() |
||||
|
|
||||
|
return resp |
||||
|
} |
||||
@ -0,0 +1,444 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
// mockResolver implements TargetResolver for tests, backed by a set of
// registered target names.
type mockResolver struct {
	targets map[string]bool
}

// HasTarget reports whether name was registered via newResolver.
func (m *mockResolver) HasTarget(name string) bool {
	return m.targets[name]
}

// newResolver builds a mockResolver that knows exactly the given names.
func newResolver(names ...string) *mockResolver {
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return &mockResolver{targets: set}
}
||||
|
|
||||
|
// Fixed IQNs shared by the login tests in this file.
const testTargetName = "iqn.2024.com.seaweedfs:vol1"
const testInitiatorName = "iqn.2024.com.test:client1"
||||
|
|
||||
|
func TestLogin(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{"single_pdu_security_to_ffp", testSinglePDUSecurityToFFP}, |
||||
|
{"two_phase_login", testTwoPhaseLogin}, |
||||
|
{"security_to_loginop_to_ffp", testSecurityToLoginOpToFFP}, |
||||
|
{"missing_initiator_name", testMissingInitiatorName}, |
||||
|
{"target_not_found", testTargetNotFound}, |
||||
|
{"discovery_session_no_target", testDiscoverySessionNoTarget}, |
||||
|
{"wrong_opcode", testWrongOpcode}, |
||||
|
{"operational_negotiation", testOperationalNegotiation}, |
||||
|
{"redirect_permanent", testRedirectPermanent}, |
||||
|
{"redirect_temporary", testRedirectTemporary}, |
||||
|
{"login_result", testLoginResult}, |
||||
|
{"unknown_key_not_understood", testUnknownKeyNotUnderstood}, |
||||
|
{"header_data_digest_none", testHeaderDataDigestNone}, |
||||
|
{"duplicate_login_pdu", testDuplicateLoginPDU}, |
||||
|
{"invalid_csg", testInvalidCSG}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func makeLoginReq(csg, nsg uint8, transit bool, params *Params) *PDU { |
||||
|
p := &PDU{} |
||||
|
p.SetOpcode(OpLoginReq) |
||||
|
p.SetLoginStages(csg, nsg) |
||||
|
if transit { |
||||
|
p.SetLoginTransit(true) |
||||
|
} |
||||
|
isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01} |
||||
|
p.SetISID(isid) |
||||
|
if params != nil { |
||||
|
p.DataSegment = params.Encode() |
||||
|
} |
||||
|
return p |
||||
|
} |
||||
|
|
||||
|
func testSinglePDUSecurityToFFP(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
resolver := newResolver(testTargetName) |
||||
|
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("TargetName", testTargetName) |
||||
|
params.Set("SessionType", "Normal") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
resp := ln.HandleLoginPDU(req, resolver) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("status class: %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
if !resp.LoginTransit() { |
||||
|
t.Fatal("transit should be set") |
||||
|
} |
||||
|
if !ln.Done() { |
||||
|
t.Fatal("should be done") |
||||
|
} |
||||
|
if resp.TSIH() == 0 { |
||||
|
t.Fatal("TSIH should be assigned") |
||||
|
} |
||||
|
|
||||
|
// Verify AuthMethod=None in response
|
||||
|
rp, err := ParseParams(resp.DataSegment) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if v, ok := rp.Get("AuthMethod"); !ok || v != "None" { |
||||
|
t.Fatalf("AuthMethod: %q, %v", v, ok) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testTwoPhaseLogin(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
resolver := newResolver(testTargetName) |
||||
|
|
||||
|
// Phase 1: Security -> LoginOp
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("TargetName", testTargetName) |
||||
|
params.Set("SessionType", "Normal") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params) |
||||
|
resp := ln.HandleLoginPDU(req, resolver) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("phase 1 status: %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
if ln.Done() { |
||||
|
t.Fatal("should not be done yet") |
||||
|
} |
||||
|
|
||||
|
// Phase 2: LoginOp -> FFP
|
||||
|
params2 := NewParams() |
||||
|
params2.Set("MaxRecvDataSegmentLength", "65536") |
||||
|
|
||||
|
req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2) |
||||
|
resp2 := ln.HandleLoginPDU(req2, resolver) |
||||
|
|
||||
|
if resp2.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("phase 2 status: %d", resp2.LoginStatusClass()) |
||||
|
} |
||||
|
if !ln.Done() { |
||||
|
t.Fatal("should be done") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSecurityToLoginOpToFFP(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
resolver := newResolver(testTargetName) |
||||
|
|
||||
|
// Security (no transit)
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("TargetName", testTargetName) |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params) |
||||
|
resp := ln.HandleLoginPDU(req, resolver) |
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("status: %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
|
||||
|
// Security -> LoginOp (transit)
|
||||
|
req2 := makeLoginReq(StageSecurityNeg, StageLoginOp, true, NewParams()) |
||||
|
resp2 := ln.HandleLoginPDU(req2, resolver) |
||||
|
if resp2.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("status: %d", resp2.LoginStatusClass()) |
||||
|
} |
||||
|
|
||||
|
// LoginOp -> FFP (transit)
|
||||
|
req3 := makeLoginReq(StageLoginOp, StageFullFeature, true, NewParams()) |
||||
|
resp3 := ln.HandleLoginPDU(req3, resolver) |
||||
|
if resp3.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("status: %d", resp3.LoginStatusClass()) |
||||
|
} |
||||
|
if !ln.Done() { |
||||
|
t.Fatal("should be done") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testMissingInitiatorName(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
resolver := newResolver(testTargetName) |
||||
|
|
||||
|
// No InitiatorName
|
||||
|
params := NewParams() |
||||
|
params.Set("TargetName", testTargetName) |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
resp := ln.HandleLoginPDU(req, resolver) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusInitiatorErr { |
||||
|
t.Fatalf("expected initiator error, got %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
if resp.LoginStatusDetail() != LoginDetailMissingParam { |
||||
|
t.Fatalf("expected missing param, got %d", resp.LoginStatusDetail()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testTargetNotFound(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
resolver := newResolver() // empty
|
||||
|
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("TargetName", "iqn.2024.com.seaweedfs:nonexistent") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
resp := ln.HandleLoginPDU(req, resolver) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusInitiatorErr { |
||||
|
t.Fatalf("expected initiator error, got %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
if resp.LoginStatusDetail() != LoginDetailNotFound { |
||||
|
t.Fatalf("expected not found, got %d", resp.LoginStatusDetail()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDiscoverySessionNoTarget(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("SessionType", "Discovery") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
resp := ln.HandleLoginPDU(req, nil) // no resolver needed
|
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("status: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) |
||||
|
} |
||||
|
if !ln.Done() { |
||||
|
t.Fatal("should be done") |
||||
|
} |
||||
|
if ln.SessionType != "Discovery" { |
||||
|
t.Fatalf("session type: %q", ln.SessionType) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWrongOpcode(t *testing.T) { |
||||
|
ln := NewLoginNegotiator(DefaultTargetConfig()) |
||||
|
|
||||
|
req := &PDU{} |
||||
|
req.SetOpcode(OpSCSICmd) // wrong
|
||||
|
resp := ln.HandleLoginPDU(req, nil) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusInitiatorErr { |
||||
|
t.Fatalf("expected error, got %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testOperationalNegotiation drives a two-step login and verifies the
// operational-parameter responses: digests forced to "None", InitialR2T
// OR'd to Yes, ImmediateData negotiated to No, and the configured
// TargetAlias declared back to the initiator.
func testOperationalNegotiation(t *testing.T) {
	config := DefaultTargetConfig()
	config.TargetAlias = "test-target"
	ln := NewLoginNegotiator(config)
	resolver := newResolver(testTargetName)

	// First: security phase
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	// Second: operational negotiation — offer values that differ from the
	// target defaults, plus digest preference lists.
	params2 := NewParams()
	params2.Set("MaxRecvDataSegmentLength", "65536")
	params2.Set("MaxBurstLength", "131072")
	params2.Set("FirstBurstLength", "32768")
	params2.Set("InitialR2T", "Yes")
	params2.Set("ImmediateData", "No")
	params2.Set("HeaderDigest", "CRC32C,None")
	params2.Set("DataDigest", "CRC32C,None")

	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("status: %d", resp.LoginStatusClass())
	}

	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}

	// Verify negotiated values
	// The target implements no digests, so both selections must be "None".
	if v, ok := rp.Get("HeaderDigest"); !ok || v != "None" {
		t.Fatalf("HeaderDigest: %q", v)
	}
	if v, ok := rp.Get("DataDigest"); !ok || v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
	if v, ok := rp.Get("InitialR2T"); !ok || v != "Yes" {
		t.Fatalf("InitialR2T: %q", v)
	}
	if v, ok := rp.Get("ImmediateData"); !ok || v != "No" {
		t.Fatalf("ImmediateData: %q", v)
	}
	if v, ok := rp.Get("TargetAlias"); !ok || v != "test-target" {
		t.Fatalf("TargetAlias: %q", v)
	}

	// Check negotiated state
	if ln.NegInitialR2T != true {
		t.Fatal("NegInitialR2T should be true")
	}
	if ln.NegImmediateData != false {
		t.Fatal("NegImmediateData should be false")
	}
}
||||
|
|
||||
|
func testRedirectPermanent(t *testing.T) { |
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams()) |
||||
|
resp := BuildRedirectResponse(req, "10.0.0.1:3260,1", true) |
||||
|
|
||||
|
if resp.LoginStatusClass() != LoginStatusRedirect { |
||||
|
t.Fatalf("class: %d", resp.LoginStatusClass()) |
||||
|
} |
||||
|
if resp.LoginStatusDetail() != LoginDetailTargetMoved { |
||||
|
t.Fatalf("detail: %d", resp.LoginStatusDetail()) |
||||
|
} |
||||
|
|
||||
|
rp, err := ParseParams(resp.DataSegment) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if v, _ := rp.Get("TargetAddress"); v != "10.0.0.1:3260,1" { |
||||
|
t.Fatalf("TargetAddress: %q", v) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testRedirectTemporary(t *testing.T) { |
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, NewParams()) |
||||
|
resp := BuildRedirectResponse(req, "192.168.1.100:3260,2", false) |
||||
|
|
||||
|
if resp.LoginStatusDetail() != LoginDetailTargetMovedTemp { |
||||
|
t.Fatalf("detail: %d", resp.LoginStatusDetail()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testLoginResult verifies that after a successful single-PDU login the
// negotiator's Result() reflects the initiator name, a "Normal" session
// type, and a non-zero assigned TSIH.
func testLoginResult(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	// Single transit straight from security negotiation to full feature.
	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	ln.HandleLoginPDU(req, resolver)

	result := ln.Result()
	if result.InitiatorName != testInitiatorName {
		t.Fatalf("initiator: %q", result.InitiatorName)
	}
	// No SessionType key was sent; "Normal" is expected as the default.
	if result.SessionType != "Normal" {
		t.Fatalf("session type: %q", result.SessionType)
	}
	if result.TSIH == 0 {
		t.Fatal("TSIH should be assigned")
	}
}
||||
|
|
||||
|
// testUnknownKeyNotUnderstood verifies that an unrecognized negotiation key
// is answered with the literal value "NotUnderstood" (RFC 7143 key
// negotiation rules) rather than rejected.
func testUnknownKeyNotUnderstood(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// Security phase first, transiting to operational negotiation.
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	// Op phase with an unknown vendor-style key.
	params2 := NewParams()
	params2.Set("X-CustomKey", "whatever")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	rp, err := ParseParams(resp.DataSegment)
	if err != nil {
		t.Fatal(err)
	}
	if v, ok := rp.Get("X-CustomKey"); !ok || v != "NotUnderstood" {
		t.Fatalf("expected NotUnderstood, got %q, %v", v, ok)
	}
}
||||
|
|
||||
|
// testHeaderDataDigestNone verifies that even when the initiator offers only
// CRC32C for the digest keys, this target's response selects "None" — the
// target always negotiates digests off.
func testHeaderDataDigestNone(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	// Security phase: identify both ends, transit to op negotiation.
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	req := makeLoginReq(StageSecurityNeg, StageLoginOp, true, params)
	ln.HandleLoginPDU(req, resolver)

	// Offer CRC32C only (no "None" fallback in the offered list).
	params2 := NewParams()
	params2.Set("HeaderDigest", "CRC32C")
	params2.Set("DataDigest", "CRC32C")
	req2 := makeLoginReq(StageLoginOp, StageFullFeature, true, params2)
	resp := ln.HandleLoginPDU(req2, resolver)

	rp, _ := ParseParams(resp.DataSegment)
	if v, _ := rp.Get("HeaderDigest"); v != "None" {
		t.Fatalf("HeaderDigest: %q (we always negotiate None)", v)
	}
	if v, _ := rp.Get("DataDigest"); v != "None" {
		t.Fatalf("DataDigest: %q", v)
	}
}
||||
|
|
||||
|
// testDuplicateLoginPDU verifies that resending the same non-transit
// security-stage PDU is tolerated: both deliveries succeed rather than the
// second one being treated as a protocol error.
func testDuplicateLoginPDU(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())
	resolver := newResolver(testTargetName)

	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)

	// Send same security PDU twice (no transit bit, so the negotiator
	// stays in the security stage after each delivery).
	req := makeLoginReq(StageSecurityNeg, StageSecurityNeg, false, params)
	resp1 := ln.HandleLoginPDU(req, resolver)
	if resp1.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("first: %d", resp1.LoginStatusClass())
	}

	// Same stage again should still work.
	resp2 := ln.HandleLoginPDU(req, resolver)
	if resp2.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("second: %d", resp2.LoginStatusClass())
	}
}
||||
|
|
||||
|
// testInvalidCSG verifies that a login PDU claiming FullFeature as its
// *current* stage (legal only as a next stage) is rejected with an
// initiator-error status class.
func testInvalidCSG(t *testing.T) {
	ln := NewLoginNegotiator(DefaultTargetConfig())

	// CSG=FullFeature (3) is invalid as a current stage.
	req := &PDU{}
	req.SetOpcode(OpLoginReq)
	req.SetLoginStages(StageFullFeature, StageFullFeature)
	req.SetLoginTransit(true)
	// A plausible ISID so the PDU is otherwise well-formed.
	isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}
	req.SetISID(isid)

	resp := ln.HandleLoginPDU(req, nil)
	if resp.LoginStatusClass() != LoginStatusInitiatorErr {
		t.Fatalf("expected error for invalid CSG, got %d", resp.LoginStatusClass())
	}
}
||||
@ -0,0 +1,183 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// iSCSI key-value text parameters (RFC 7143, Section 6).
|
||||
|
// Parameters are encoded as "Key=Value\0" in the data segment.
|
||||
|
|
||||
|
// Sentinel errors returned by ParseParams. ErrMalformedParam and
// ErrDuplicateKey are wrapped with the offending text via %w (match with
// errors.Is); ErrEmptyKey is returned directly.
var (
	ErrMalformedParam = errors.New("iscsi: malformed parameter (missing '=')")
	ErrEmptyKey       = errors.New("iscsi: empty parameter key")
	ErrDuplicateKey   = errors.New("iscsi: duplicate parameter key")
)
||||
|
|
||||
|
// Params is an ordered list of iSCSI key-value parameters.
// Order matters for negotiation, so we use a slice rather than a map.
// Lookup is linear; parameter lists are small (a handful of keys per PDU).
type Params struct {
	items []paramItem
}

// paramItem is a single Key=Value pair.
type paramItem struct {
	key   string
	value string
}
||||
|
|
||||
|
// ParseParams decodes "Key=Value\0Key=Value\0..." from raw bytes.
|
||||
|
// Empty trailing segments (from a trailing \0) are ignored.
|
||||
|
// Returns ErrDuplicateKey if the same key appears more than once.
|
||||
|
func ParseParams(data []byte) (*Params, error) { |
||||
|
p := &Params{} |
||||
|
if len(data) == 0 { |
||||
|
return p, nil |
||||
|
} |
||||
|
|
||||
|
seen := make(map[string]bool) |
||||
|
s := string(data) |
||||
|
|
||||
|
// Split by null separator
|
||||
|
parts := strings.Split(s, "\x00") |
||||
|
for _, part := range parts { |
||||
|
if part == "" { |
||||
|
continue // trailing null or empty
|
||||
|
} |
||||
|
idx := strings.IndexByte(part, '=') |
||||
|
if idx < 0 { |
||||
|
return nil, fmt.Errorf("%w: %q", ErrMalformedParam, part) |
||||
|
} |
||||
|
key := part[:idx] |
||||
|
if key == "" { |
||||
|
return nil, ErrEmptyKey |
||||
|
} |
||||
|
value := part[idx+1:] |
||||
|
|
||||
|
if seen[key] { |
||||
|
return nil, fmt.Errorf("%w: %q", ErrDuplicateKey, key) |
||||
|
} |
||||
|
seen[key] = true |
||||
|
p.items = append(p.items, paramItem{key: key, value: value}) |
||||
|
} |
||||
|
|
||||
|
return p, nil |
||||
|
} |
||||
|
|
||||
|
// Encode serializes parameters to "Key=Value\0" format.
|
||||
|
func (p *Params) Encode() []byte { |
||||
|
if len(p.items) == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
var b strings.Builder |
||||
|
for _, item := range p.items { |
||||
|
b.WriteString(item.key) |
||||
|
b.WriteByte('=') |
||||
|
b.WriteString(item.value) |
||||
|
b.WriteByte(0) |
||||
|
} |
||||
|
return []byte(b.String()) |
||||
|
} |
||||
|
|
||||
|
// Get returns the value for a key, or ("", false) if not present.
|
||||
|
func (p *Params) Get(key string) (string, bool) { |
||||
|
for _, item := range p.items { |
||||
|
if item.key == key { |
||||
|
return item.value, true |
||||
|
} |
||||
|
} |
||||
|
return "", false |
||||
|
} |
||||
|
|
||||
|
// Set adds or replaces a parameter. If the key already exists, its value
|
||||
|
// is updated in place. Otherwise, the key is appended.
|
||||
|
func (p *Params) Set(key, value string) { |
||||
|
for i, item := range p.items { |
||||
|
if item.key == key { |
||||
|
p.items[i].value = value |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
p.items = append(p.items, paramItem{key: key, value: value}) |
||||
|
} |
||||
|
|
||||
|
// Del removes a key if present.
|
||||
|
func (p *Params) Del(key string) { |
||||
|
for i, item := range p.items { |
||||
|
if item.key == key { |
||||
|
p.items = append(p.items[:i], p.items[i+1:]...) |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Keys returns all keys in order.
|
||||
|
func (p *Params) Keys() []string { |
||||
|
keys := make([]string, len(p.items)) |
||||
|
for i, item := range p.items { |
||||
|
keys[i] = item.key |
||||
|
} |
||||
|
return keys |
||||
|
} |
||||
|
|
||||
|
// Len returns the number of parameters currently held.
func (p *Params) Len() int { return len(p.items) }
||||
|
|
||||
|
// Each iterates over all key-value pairs in order.
|
||||
|
func (p *Params) Each(fn func(key, value string)) { |
||||
|
for _, item := range p.items { |
||||
|
fn(item.key, item.value) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// --- Negotiation helpers ---
|
||||
|
|
||||
|
// NegotiateNumber applies the iSCSI numeric negotiation rule (RFC 7143,
// Section 6.1): the offered value is first clamped into the key's legal
// range [lo, hi], and the negotiated result is the smaller of the clamped
// offer and our own value. Both declarative keys (e.g.
// MaxRecvDataSegmentLength) and negotiated keys (e.g. MaxBurstLength)
// converge on the lower of the two sides' values, so min is the one rule
// applied here.
//
// Returns an error if offered is not a valid base-10 integer.
func NegotiateNumber(offered string, ours int, lo, hi int) (int, error) {
	v, err := strconv.Atoi(offered)
	if err != nil {
		return 0, fmt.Errorf("iscsi: invalid numeric value %q: %w", offered, err)
	}
	// Clamp the offer into the key's legal range before comparing.
	if v < lo {
		v = lo
	} else if v > hi {
		v = hi
	}
	// Standard negotiation: result is min(offer, ours).
	if ours < v {
		return ours, nil
	}
	return v, nil
}
||||
|
|
||||
|
// NegotiateBool applies boolean AND negotiation (RFC 7143): the result is
// true only when the initiator offered "Yes" and our side also wants true.
// Anything other than "Yes"/"No" is an error.
func NegotiateBool(offered string, ours bool) (bool, error) {
	if offered == "No" {
		return false, nil
	}
	if offered == "Yes" {
		return ours, nil
	}
	return false, fmt.Errorf("iscsi: invalid boolean value %q", offered)
}
||||
|
|
||||
|
// BoolStr renders a boolean as the iSCSI text values "Yes"/"No".
func BoolStr(v bool) string {
	if !v {
		return "No"
	}
	return "Yes"
}
||||
|
|
||||
|
// NewParams creates a new empty Params, ready for Set calls.
func NewParams() *Params {
	return &Params{}
}
||||
@ -0,0 +1,357 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
// TestParams is the table-driven entry point for all Params tests; each row
// runs as a named subtest.
func TestParams(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"parse_single", testParseSingle},
		{"parse_multiple", testParseMultiple},
		{"parse_empty", testParseEmpty},
		{"parse_trailing_null", testParseTrailingNull},
		{"parse_malformed_no_equals", testParseMalformedNoEquals},
		{"parse_empty_key", testParseEmptyKey},
		{"parse_empty_value", testParseEmptyValue},
		{"parse_duplicate_key", testParseDuplicateKey},
		{"roundtrip", testParamsRoundtrip},
		{"encode_empty", testEncodeEmpty},
		{"set_new_key", testSetNewKey},
		{"set_existing_key", testSetExistingKey},
		{"del_key", testDelKey},
		{"del_nonexistent", testDelNonexistent},
		{"keys_order", testKeysOrder},
		{"each", testEach},
		{"negotiate_number_min", testNegotiateNumberMin},
		{"negotiate_number_clamp", testNegotiateNumberClamp},
		{"negotiate_number_invalid", testNegotiateNumberInvalid},
		{"negotiate_bool_and", testNegotiateBoolAnd},
		{"negotiate_bool_invalid", testNegotiateBoolInvalid},
		{"bool_str", testBoolStr},
		{"value_with_equals", testValueWithEquals},
		{"parse_binary_data_value", testParseBinaryDataValue},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
||||
|
|
||||
|
// testParseSingle: one Key=Value pair decodes to a single entry.
func testParseSingle(t *testing.T) {
	data := []byte("TargetName=iqn.2024.com.seaweedfs:vol1\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 1 {
		t.Fatalf("expected 1 param, got %d", p.Len())
	}
	v, ok := p.Get("TargetName")
	if !ok || v != "iqn.2024.com.seaweedfs:vol1" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseMultiple: several pairs decode with their wire order preserved.
func testParseMultiple(t *testing.T) {
	data := []byte("InitiatorName=iqn.2024.com.test\x00TargetName=iqn.2024.com.seaweedfs:vol1\x00SessionType=Normal\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 3 {
		t.Fatalf("expected 3 params, got %d", p.Len())
	}
	keys := p.Keys()
	if keys[0] != "InitiatorName" || keys[1] != "TargetName" || keys[2] != "SessionType" {
		t.Fatalf("wrong order: %v", keys)
	}
}

// testParseEmpty: nil and zero-length input both decode to an empty list.
func testParseEmpty(t *testing.T) {
	p, err := ParseParams(nil)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 0 {
		t.Fatalf("expected 0, got %d", p.Len())
	}

	p2, err := ParseParams([]byte{})
	if err != nil {
		t.Fatal(err)
	}
	if p2.Len() != 0 {
		t.Fatalf("expected 0, got %d", p2.Len())
	}
}
|
|
||||
|
// testParseTrailingNull: extra trailing null separators are ignored.
func testParseTrailingNull(t *testing.T) {
	// Multiple trailing nulls should not cause errors
	data := []byte("Key=Value\x00\x00\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
}

// testParseMalformedNoEquals: a segment without '=' is rejected.
func testParseMalformedNoEquals(t *testing.T) {
	data := []byte("KeyWithoutValue\x00")
	_, err := ParseParams(data)
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "malformed") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testParseEmptyKey: "=Value" is rejected with the bare ErrEmptyKey sentinel.
func testParseEmptyKey(t *testing.T) {
	data := []byte("=Value\x00")
	_, err := ParseParams(data)
	if err != ErrEmptyKey {
		t.Fatalf("expected ErrEmptyKey, got %v", err)
	}
}
||||
|
|
||||
|
// testParseEmptyValue: "Key=" decodes to a present key with empty value.
func testParseEmptyValue(t *testing.T) {
	// Empty value is valid in iSCSI (e.g., reject with empty value)
	data := []byte("Key=\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseDuplicateKey: repeating a key is rejected.
func testParseDuplicateKey(t *testing.T) {
	data := []byte("Key=Value1\x00Key=Value2\x00")
	_, err := ParseParams(data)
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "duplicate") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testParamsRoundtrip: decode then Encode reproduces the exact wire bytes.
func testParamsRoundtrip(t *testing.T) {
	original := []byte("InitiatorName=iqn.2024.com.test\x00MaxRecvDataSegmentLength=65536\x00")
	p, err := ParseParams(original)
	if err != nil {
		t.Fatal(err)
	}

	encoded := p.Encode()
	if !bytes.Equal(encoded, original) {
		t.Fatalf("roundtrip mismatch:\n got: %q\n want: %q", encoded, original)
	}
}
||||
|
|
||||
|
// testEncodeEmpty: encoding an empty Params yields nil, not an empty slice.
func testEncodeEmpty(t *testing.T) {
	p := NewParams()
	if encoded := p.Encode(); encoded != nil {
		t.Fatalf("expected nil, got %q", encoded)
	}
}

// testSetNewKey: Set appends distinct keys and values are retrievable.
func testSetNewKey(t *testing.T) {
	p := NewParams()
	p.Set("Key1", "Val1")
	p.Set("Key2", "Val2")
	if p.Len() != 2 {
		t.Fatalf("expected 2, got %d", p.Len())
	}
	v, _ := p.Get("Key2")
	if v != "Val2" {
		t.Fatalf("got %q", v)
	}
}

// testSetExistingKey: Set on an existing key replaces in place, no growth.
func testSetExistingKey(t *testing.T) {
	p := NewParams()
	p.Set("Key", "Old")
	p.Set("Key", "New")
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
	v, _ := p.Get("Key")
	if v != "New" {
		t.Fatalf("got %q", v)
	}
}
||||
|
|
||||
|
// testDelKey: Del removes the key and preserves the order of the rest.
func testDelKey(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Set("B", "2")
	p.Set("C", "3")
	p.Del("B")
	if p.Len() != 2 {
		t.Fatalf("expected 2, got %d", p.Len())
	}
	if _, ok := p.Get("B"); ok {
		t.Fatal("B should be deleted")
	}
	keys := p.Keys()
	if keys[0] != "A" || keys[1] != "C" {
		t.Fatalf("wrong keys: %v", keys)
	}
}

// testDelNonexistent: deleting an absent key is a harmless no-op.
func testDelNonexistent(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Del("Z") // should not panic
	if p.Len() != 1 {
		t.Fatalf("expected 1, got %d", p.Len())
	}
}

// testKeysOrder: Keys reflects insertion order, not lexical order.
func testKeysOrder(t *testing.T) {
	data := []byte("Z=1\x00A=2\x00M=3\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	keys := p.Keys()
	if keys[0] != "Z" || keys[1] != "A" || keys[2] != "M" {
		t.Fatalf("order not preserved: %v", keys)
	}
}
||||
|
|
||||
|
// testEach: Each visits every pair once, in stored order.
func testEach(t *testing.T) {
	p := NewParams()
	p.Set("A", "1")
	p.Set("B", "2")
	var collected []string
	p.Each(func(k, v string) {
		collected = append(collected, k+"="+v)
	})
	if len(collected) != 2 || collected[0] != "A=1" || collected[1] != "B=2" {
		t.Fatalf("Each: %v", collected)
	}
}

// testNegotiateNumberMin: the result is the minimum of offer and ours,
// regardless of which side is smaller.
func testNegotiateNumberMin(t *testing.T) {
	// Both sides offer a value, result is min
	result, err := NegotiateNumber("65536", 262144, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}

	// Our value is smaller
	result, err = NegotiateNumber("262144", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
}

// testNegotiateNumberClamp: out-of-range offers are clamped into the key's
// [min, max] range before taking the minimum.
func testNegotiateNumberClamp(t *testing.T) {
	// Offered value below minimum
	result, err := NegotiateNumber("100", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 512 {
		t.Fatalf("expected 512, got %d", result)
	}

	// Offered value above maximum
	result, err = NegotiateNumber("99999999", 65536, 512, 16777215)
	if err != nil {
		t.Fatal(err)
	}
	if result != 65536 {
		t.Fatalf("expected 65536, got %d", result)
	}
}
||||
|
|
||||
|
// testNegotiateNumberInvalid: non-numeric offers produce an error.
func testNegotiateNumberInvalid(t *testing.T) {
	_, err := NegotiateNumber("notanumber", 65536, 512, 16777215)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testNegotiateBoolAnd: exhaustive truth table for AND-semantics booleans.
func testNegotiateBoolAnd(t *testing.T) {
	// Both yes
	r, err := NegotiateBool("Yes", true)
	if err != nil || !r {
		t.Fatalf("Yes+true = %v, %v", r, err)
	}
	// Initiator yes, target no
	r, err = NegotiateBool("Yes", false)
	if err != nil || r {
		t.Fatalf("Yes+false = %v, %v", r, err)
	}
	// Initiator no, target yes
	r, err = NegotiateBool("No", true)
	if err != nil || r {
		t.Fatalf("No+true = %v, %v", r, err)
	}
	// Both no
	r, err = NegotiateBool("No", false)
	if err != nil || r {
		t.Fatalf("No+false = %v, %v", r, err)
	}
}

// testNegotiateBoolInvalid: anything but "Yes"/"No" is an error.
func testNegotiateBoolInvalid(t *testing.T) {
	_, err := NegotiateBool("Maybe", true)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testBoolStr: boolean-to-text rendering.
func testBoolStr(t *testing.T) {
	if BoolStr(true) != "Yes" {
		t.Fatal("true should be Yes")
	}
	if BoolStr(false) != "No" {
		t.Fatal("false should be No")
	}
}
||||
|
|
||||
|
// testValueWithEquals: only the first '=' splits key from value, so values
// may themselves contain '='.
func testValueWithEquals(t *testing.T) {
	// Value containing '=' is valid (only first '=' splits key from value)
	data := []byte("TargetAddress=10.0.0.1:3260,1\x00Key=a=b=c\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "a=b=c" {
		t.Fatalf("got %q, %v", v, ok)
	}
}

// testParseBinaryDataValue: values may carry arbitrary non-null bytes.
func testParseBinaryDataValue(t *testing.T) {
	// Values can contain arbitrary bytes except null
	data := []byte("Key=\x01\x02\x03\x00")
	p, err := ParseParams(data)
	if err != nil {
		t.Fatal(err)
	}
	v, ok := p.Get("Key")
	if !ok || v != "\x01\x02\x03" {
		t.Fatalf("got %q, %v", v, ok)
	}
}
||||
@ -0,0 +1,447 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
) |
||||
|
|
||||
|
// BHS (Basic Header Segment) is 48 bytes, big-endian.
// RFC 7143, Section 12.1. Every iSCSI PDU starts with exactly one BHS.
const BHSLength = 48
||||
|
|
||||
|
// Opcode constants — initiator opcodes (RFC 7143, Section 12.1.1).
// These appear in the lower 6 bits of BHS byte 0 on PDUs sent by the
// initiator.
const (
	OpNOPOut       uint8 = 0x00
	OpSCSICmd      uint8 = 0x01
	OpSCSITaskMgmt uint8 = 0x02
	OpLoginReq     uint8 = 0x03
	OpTextReq      uint8 = 0x04
	OpSCSIDataOut  uint8 = 0x05
	OpLogoutReq    uint8 = 0x06
	OpSNACKReq     uint8 = 0x0c
)
||||
|
|
||||
|
// Target opcodes (RFC 7143, Section 12.1.1) — PDUs sent by the target
// back to the initiator.
const (
	OpNOPIn        uint8 = 0x20
	OpSCSIResp     uint8 = 0x21
	OpSCSITaskResp uint8 = 0x22
	OpLoginResp    uint8 = 0x23
	OpTextResp     uint8 = 0x24
	OpSCSIDataIn   uint8 = 0x25
	OpLogoutResp   uint8 = 0x26
	OpR2T          uint8 = 0x31
	OpAsyncMsg     uint8 = 0x32
	OpReject       uint8 = 0x3f
)
||||
|
|
||||
|
// BHS flag masks. Note several flags share bit positions — which meaning
// applies depends on the PDU's opcode (e.g. 0x40 in byte 1 is R for a SCSI
// command, C for a login PDU, and A for Data-Out).
const (
	opcMask = 0x3f // lower 6 bits of byte 0

	FlagI = 0x40 // Immediate delivery bit (byte 0)

	FlagF = 0x80 // Final bit (byte 1)
	FlagR = 0x40 // Read bit (byte 1, SCSI command)
	FlagW = 0x20 // Write bit (byte 1, SCSI command)

	FlagC = 0x40 // Continue bit (byte 1, login)
	FlagT = 0x80 // Transit bit (byte 1, login)

	FlagS = 0x01 // Status bit (byte 1, Data-In)

	FlagU   = 0x02 // Underflow bit (byte 1, SCSI Response)
	FlagO   = 0x04 // Overflow bit (byte 1, SCSI Response)
	FlagBiU = 0x08 // Bidi underflow (byte 1, SCSI Response)
	FlagBiO = 0x10 // Bidi overflow (byte 1, SCSI Response)

	FlagA = 0x40 // Acknowledge bit (byte 1, Data-Out)
)
||||
|
|
||||
|
// Login stage constants (CSG/NSG, 2-bit fields in byte 1 of login PDU).
// Value 2 is reserved by RFC 7143 and intentionally has no constant here.
const (
	StageSecurityNeg uint8 = 0 // Security Negotiation
	StageLoginOp     uint8 = 1 // Login Operational Negotiation
	StageFullFeature uint8 = 3 // Full Feature Phase
)
||||
|
|
||||
|
// SCSI status codes carried in the Status field of SCSI Response and final
// Data-In PDUs.
const (
	SCSIStatusGood         uint8 = 0x00
	SCSIStatusCheckCond    uint8 = 0x02
	SCSIStatusBusy         uint8 = 0x08
	SCSIStatusResvConflict uint8 = 0x18
)
||||
|
|
||||
|
// iSCSI response codes (SCSI Response PDU Response field). Only "command
// completed at target" is used by this implementation so far.
const (
	ISCSIRespCompleted uint8 = 0x00
)
||||
|
|
||||
|
// MaxDataSegmentLength limits the maximum data segment we'll accept.
// The iSCSI data segment length is a 3-byte field (max 16MB-1).
// We cap at 8MB which is well above typical MaxRecvDataSegmentLength values.
const MaxDataSegmentLength = 8 * 1024 * 1024 // 8 MB

// Sentinel errors surfaced by PDU decoding.
var (
	ErrPDUTruncated     = errors.New("iscsi: PDU truncated")
	ErrPDUTooLarge      = errors.New("iscsi: data segment exceeds maximum length")
	ErrInvalidAHSLength = errors.New("iscsi: invalid AHS length (not multiple of 4)")
	ErrUnknownOpcode    = errors.New("iscsi: unknown opcode")
)
||||
|
|
||||
|
// PDU represents a full iSCSI Protocol Data Unit.
|
||||
|
// PDU represents a full iSCSI Protocol Data Unit: the fixed 48-byte BHS
// plus optional AHS and data segment.
type PDU struct {
	BHS         [BHSLength]byte
	AHS         []byte // Additional Header Segment (multiple of 4 bytes)
	DataSegment []byte // Data segment (padded to 4-byte boundary on wire)
}
||||
|
|
||||
|
// --- BHS field accessors ---
|
||||
|
|
||||
|
// Opcode returns the opcode (lower 6 bits of byte 0).
func (p *PDU) Opcode() uint8 { return p.BHS[0] & opcMask }

// SetOpcode sets the opcode (lower 6 bits of byte 0), preserving the
// immediate-delivery bit and the reserved top bit.
func (p *PDU) SetOpcode(op uint8) {
	p.BHS[0] = (p.BHS[0] & ^uint8(opcMask)) | (op & opcMask)
}

// Immediate returns true if the immediate delivery bit is set.
func (p *PDU) Immediate() bool { return p.BHS[0]&FlagI != 0 }

// SetImmediate sets or clears the immediate delivery bit.
func (p *PDU) SetImmediate(v bool) {
	if v {
		p.BHS[0] |= FlagI
	} else {
		p.BHS[0] &^= FlagI
	}
}
||||
|
|
||||
|
// OpSpecific1 returns byte 1 (opcode-specific flags).
func (p *PDU) OpSpecific1() uint8 { return p.BHS[1] }

// SetOpSpecific1 sets byte 1.
func (p *PDU) SetOpSpecific1(v uint8) { p.BHS[1] = v }

// TotalAHSLength returns the total AHS length in 4-byte words (byte 4).
func (p *PDU) TotalAHSLength() uint8 { return p.BHS[4] }

// DataSegmentLength returns the 3-byte data segment length (bytes 5-7).
func (p *PDU) DataSegmentLength() uint32 {
	return uint32(p.BHS[5])<<16 | uint32(p.BHS[6])<<8 | uint32(p.BHS[7])
}

// SetDataSegmentLength sets the 3-byte data segment length field.
// Only the low 24 bits of n are representable; higher bits are dropped.
func (p *PDU) SetDataSegmentLength(n uint32) {
	p.BHS[5] = byte(n >> 16)
	p.BHS[6] = byte(n >> 8)
	p.BHS[7] = byte(n)
}

// LUN returns the 8-byte LUN field (bytes 8-15).
func (p *PDU) LUN() uint64 { return binary.BigEndian.Uint64(p.BHS[8:16]) }

// SetLUN sets the LUN field.
func (p *PDU) SetLUN(lun uint64) { binary.BigEndian.PutUint64(p.BHS[8:16], lun) }

// InitiatorTaskTag returns bytes 16-19.
func (p *PDU) InitiatorTaskTag() uint32 { return binary.BigEndian.Uint32(p.BHS[16:20]) }

// SetInitiatorTaskTag sets bytes 16-19.
func (p *PDU) SetInitiatorTaskTag(tag uint32) { binary.BigEndian.PutUint32(p.BHS[16:20], tag) }

// Field32 reads a generic 4-byte big-endian field at the given BHS offset.
// The caller must keep offset within [0, BHSLength-4].
func (p *PDU) Field32(offset int) uint32 { return binary.BigEndian.Uint32(p.BHS[offset : offset+4]) }

// SetField32 writes a generic 4-byte big-endian field at the given BHS offset.
func (p *PDU) SetField32(offset int, v uint32) {
	binary.BigEndian.PutUint32(p.BHS[offset:offset+4], v)
}
||||
|
|
||||
|
// --- Common BHS offsets for named fields ---
|
||||
|
|
||||
|
// TSIH returns the Target Session Identifying Handle (bytes 14-15, login PDU).
func (p *PDU) TSIH() uint16 { return binary.BigEndian.Uint16(p.BHS[14:16]) }

// SetTSIH sets the TSIH field.
func (p *PDU) SetTSIH(v uint16) { binary.BigEndian.PutUint16(p.BHS[14:16], v) }

// The accessors below intentionally alias the same BHS offsets:
// CmdSN/StatSN both live at bytes 24-27 and ExpStatSN/ExpCmdSN at bytes
// 28-31. Which name applies depends on the PDU's direction (initiator
// request vs. target response); the same applies to DataSN/R2TSN at 36-39.

// CmdSN returns bytes 24-27 (initiator request PDUs).
func (p *PDU) CmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetCmdSN sets bytes 24-27.
func (p *PDU) SetCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpStatSN returns bytes 28-31 (initiator request PDUs).
func (p *PDU) ExpStatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpStatSN sets bytes 28-31.
func (p *PDU) SetExpStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// StatSN returns bytes 24-27 (target response PDUs).
func (p *PDU) StatSN() uint32 { return binary.BigEndian.Uint32(p.BHS[24:28]) }

// SetStatSN sets bytes 24-27.
func (p *PDU) SetStatSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[24:28], v) }

// ExpCmdSN returns bytes 28-31 (target response PDUs).
func (p *PDU) ExpCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[28:32]) }

// SetExpCmdSN sets bytes 28-31.
func (p *PDU) SetExpCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[28:32], v) }

// MaxCmdSN returns bytes 32-35 (target response PDUs).
func (p *PDU) MaxCmdSN() uint32 { return binary.BigEndian.Uint32(p.BHS[32:36]) }

// SetMaxCmdSN sets bytes 32-35.
func (p *PDU) SetMaxCmdSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[32:36], v) }

// DataSN returns bytes 36-39 (Data-In / Data-Out PDUs).
func (p *PDU) DataSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) }

// SetDataSN sets bytes 36-39.
func (p *PDU) SetDataSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) }

// BufferOffset returns bytes 40-43 (Data-In / Data-Out PDUs).
func (p *PDU) BufferOffset() uint32 { return binary.BigEndian.Uint32(p.BHS[40:44]) }

// SetBufferOffset sets bytes 40-43.
func (p *PDU) SetBufferOffset(v uint32) { binary.BigEndian.PutUint32(p.BHS[40:44], v) }

// R2TSN returns bytes 36-39 (R2T PDUs).
func (p *PDU) R2TSN() uint32 { return binary.BigEndian.Uint32(p.BHS[36:40]) }

// SetR2TSN sets bytes 36-39.
func (p *PDU) SetR2TSN(v uint32) { binary.BigEndian.PutUint32(p.BHS[36:40], v) }

// DesiredDataLength returns bytes 44-47 (R2T PDUs).
func (p *PDU) DesiredDataLength() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) }

// SetDesiredDataLength sets bytes 44-47.
func (p *PDU) SetDesiredDataLength(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) }
||||
|
|
||||
|
// ExpectedDataTransferLength returns bytes 20-23 (SCSI Command PDUs).
|
||||
|
func (p *PDU) ExpectedDataTransferLength() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } |
||||
|
|
||||
|
// SetExpectedDataTransferLength sets bytes 20-23.
|
||||
|
func (p *PDU) SetExpectedDataTransferLength(v uint32) { |
||||
|
binary.BigEndian.PutUint32(p.BHS[20:24], v) |
||||
|
} |
||||
|
|
||||
|
// TargetTransferTag returns bytes 20-23 (target PDUs: R2T, Data-In, etc.).
|
||||
|
func (p *PDU) TargetTransferTag() uint32 { return binary.BigEndian.Uint32(p.BHS[20:24]) } |
||||
|
|
||||
|
// SetTargetTransferTag sets bytes 20-23.
|
||||
|
func (p *PDU) SetTargetTransferTag(v uint32) { binary.BigEndian.PutUint32(p.BHS[20:24], v) } |
||||
|
|
||||
|
// ISID returns the 6-byte Initiator Session ID (bytes 8-13, login PDU).
|
||||
|
func (p *PDU) ISID() [6]byte { |
||||
|
var id [6]byte |
||||
|
copy(id[:], p.BHS[8:14]) |
||||
|
return id |
||||
|
} |
||||
|
|
||||
|
// SetISID sets the 6-byte ISID field.
|
||||
|
func (p *PDU) SetISID(id [6]byte) { copy(p.BHS[8:14], id[:]) } |
||||
|
|
||||
|
// CDB returns the 16-byte SCSI CDB from the BHS (bytes 32-47, SCSI Command PDU).
|
||||
|
func (p *PDU) CDB() [16]byte { |
||||
|
var cdb [16]byte |
||||
|
copy(cdb[:], p.BHS[32:48]) |
||||
|
return cdb |
||||
|
} |
||||
|
|
||||
|
// SetCDB sets the 16-byte SCSI CDB in the BHS.
|
||||
|
func (p *PDU) SetCDB(cdb [16]byte) { copy(p.BHS[32:48], cdb[:]) } |
||||
|
|
||||
|
// ResidualCount returns bytes 44-47 (SCSI Response PDUs).
|
||||
|
func (p *PDU) ResidualCount() uint32 { return binary.BigEndian.Uint32(p.BHS[44:48]) } |
||||
|
|
||||
|
// SetResidualCount sets bytes 44-47.
|
||||
|
func (p *PDU) SetResidualCount(v uint32) { binary.BigEndian.PutUint32(p.BHS[44:48], v) } |
||||
|
|
||||
|
// --- Wire I/O ---
|
||||
|
|
||||
|
// pad4 rounds n up to the next multiple of 4. iSCSI AHS and data
// segments are padded to 4-byte boundaries on the wire (RFC 7143).
func pad4(n uint32) uint32 {
	if rem := n % 4; rem != 0 {
		return n + (4 - rem)
	}
	return n
}
||||
|
|
||||
|
// ReadPDU reads a complete PDU from r.
|
||||
|
func ReadPDU(r io.Reader) (*PDU, error) { |
||||
|
p := &PDU{} |
||||
|
|
||||
|
// Read BHS (48 bytes)
|
||||
|
if _, err := io.ReadFull(r, p.BHS[:]); err != nil { |
||||
|
if err == io.ErrUnexpectedEOF { |
||||
|
return nil, ErrPDUTruncated |
||||
|
} |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
// AHS
|
||||
|
ahsLen := uint32(p.TotalAHSLength()) * 4 |
||||
|
if ahsLen > 0 { |
||||
|
p.AHS = make([]byte, ahsLen) |
||||
|
if _, err := io.ReadFull(r, p.AHS); err != nil { |
||||
|
if err == io.ErrUnexpectedEOF { |
||||
|
return nil, ErrPDUTruncated |
||||
|
} |
||||
|
return nil, err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Data segment
|
||||
|
dsLen := p.DataSegmentLength() |
||||
|
if dsLen > MaxDataSegmentLength { |
||||
|
return nil, fmt.Errorf("%w: %d bytes", ErrPDUTooLarge, dsLen) |
||||
|
} |
||||
|
|
||||
|
if dsLen > 0 { |
||||
|
paddedLen := pad4(dsLen) |
||||
|
buf := make([]byte, paddedLen) |
||||
|
if _, err := io.ReadFull(r, buf); err != nil { |
||||
|
if err == io.ErrUnexpectedEOF { |
||||
|
return nil, ErrPDUTruncated |
||||
|
} |
||||
|
return nil, err |
||||
|
} |
||||
|
p.DataSegment = buf[:dsLen] // strip padding
|
||||
|
} |
||||
|
|
||||
|
return p, nil |
||||
|
} |
||||
|
|
||||
|
// WritePDU writes a complete PDU to w, with proper padding.
|
||||
|
func WritePDU(w io.Writer, p *PDU) error { |
||||
|
// Update header lengths from actual AHS/DataSegment
|
||||
|
if len(p.AHS) > 0 { |
||||
|
if len(p.AHS)%4 != 0 { |
||||
|
return ErrInvalidAHSLength |
||||
|
} |
||||
|
p.BHS[4] = uint8(len(p.AHS) / 4) |
||||
|
} else { |
||||
|
p.BHS[4] = 0 |
||||
|
} |
||||
|
p.SetDataSegmentLength(uint32(len(p.DataSegment))) |
||||
|
|
||||
|
// Write BHS
|
||||
|
if _, err := w.Write(p.BHS[:]); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// Write AHS
|
||||
|
if len(p.AHS) > 0 { |
||||
|
if _, err := w.Write(p.AHS); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Write data segment with padding
|
||||
|
if len(p.DataSegment) > 0 { |
||||
|
if _, err := w.Write(p.DataSegment); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
padLen := pad4(uint32(len(p.DataSegment))) - uint32(len(p.DataSegment)) |
||||
|
if padLen > 0 { |
||||
|
var pad [3]byte |
||||
|
if _, err := w.Write(pad[:padLen]); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// OpcodeName returns a human-readable name for the given opcode.
|
||||
|
func OpcodeName(op uint8) string { |
||||
|
switch op { |
||||
|
case OpNOPOut: |
||||
|
return "NOP-Out" |
||||
|
case OpSCSICmd: |
||||
|
return "SCSI-Command" |
||||
|
case OpSCSITaskMgmt: |
||||
|
return "SCSI-Task-Mgmt" |
||||
|
case OpLoginReq: |
||||
|
return "Login-Request" |
||||
|
case OpTextReq: |
||||
|
return "Text-Request" |
||||
|
case OpSCSIDataOut: |
||||
|
return "SCSI-Data-Out" |
||||
|
case OpLogoutReq: |
||||
|
return "Logout-Request" |
||||
|
case OpSNACKReq: |
||||
|
return "SNACK-Request" |
||||
|
case OpNOPIn: |
||||
|
return "NOP-In" |
||||
|
case OpSCSIResp: |
||||
|
return "SCSI-Response" |
||||
|
case OpSCSITaskResp: |
||||
|
return "SCSI-Task-Mgmt-Response" |
||||
|
case OpLoginResp: |
||||
|
return "Login-Response" |
||||
|
case OpTextResp: |
||||
|
return "Text-Response" |
||||
|
case OpSCSIDataIn: |
||||
|
return "SCSI-Data-In" |
||||
|
case OpLogoutResp: |
||||
|
return "Logout-Response" |
||||
|
case OpR2T: |
||||
|
return "R2T" |
||||
|
case OpAsyncMsg: |
||||
|
return "Async-Message" |
||||
|
case OpReject: |
||||
|
return "Reject" |
||||
|
default: |
||||
|
return fmt.Sprintf("Unknown(0x%02x)", op) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Login PDU helpers
|
||||
|
|
||||
|
// LoginCSG returns the Current Stage from byte 1 of a login PDU (bits 3-2).
|
||||
|
func (p *PDU) LoginCSG() uint8 { return (p.BHS[1] >> 2) & 0x03 } |
||||
|
|
||||
|
// LoginNSG returns the Next Stage from byte 1 of a login PDU (bits 1-0).
|
||||
|
func (p *PDU) LoginNSG() uint8 { return p.BHS[1] & 0x03 } |
||||
|
|
||||
|
// SetLoginStages sets the CSG and NSG fields in byte 1, preserving T and C flags.
|
||||
|
func (p *PDU) SetLoginStages(csg, nsg uint8) { |
||||
|
p.BHS[1] = (p.BHS[1] & 0xF0) | ((csg & 0x03) << 2) | (nsg & 0x03) |
||||
|
} |
||||
|
|
||||
|
// LoginTransit returns true if the Transit bit is set (byte 1, bit 7).
|
||||
|
func (p *PDU) LoginTransit() bool { return p.BHS[1]&FlagT != 0 } |
||||
|
|
||||
|
// SetLoginTransit sets or clears the Transit bit.
|
||||
|
func (p *PDU) SetLoginTransit(v bool) { |
||||
|
if v { |
||||
|
p.BHS[1] |= FlagT |
||||
|
} else { |
||||
|
p.BHS[1] &^= FlagT |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// LoginContinue returns true if the Continue bit is set.
|
||||
|
func (p *PDU) LoginContinue() bool { return p.BHS[1]&FlagC != 0 } |
||||
|
|
||||
|
// LoginStatusClass returns the status class (byte 36) of a login response.
|
||||
|
func (p *PDU) LoginStatusClass() uint8 { return p.BHS[36] } |
||||
|
|
||||
|
// LoginStatusDetail returns the status detail (byte 37) of a login response.
|
||||
|
func (p *PDU) LoginStatusDetail() uint8 { return p.BHS[37] } |
||||
|
|
||||
|
// SetLoginStatus sets the status class and detail in a login response.
|
||||
|
func (p *PDU) SetLoginStatus(class, detail uint8) { |
||||
|
p.BHS[36] = class |
||||
|
p.BHS[37] = detail |
||||
|
} |
||||
|
|
||||
|
// SCSIResponse returns the iSCSI response byte (byte 2) of a SCSI Response PDU.
|
||||
|
func (p *PDU) SCSIResponse() uint8 { return p.BHS[2] } |
||||
|
|
||||
|
// SetSCSIResponse sets byte 2.
|
||||
|
func (p *PDU) SetSCSIResponse(v uint8) { p.BHS[2] = v } |
||||
|
|
||||
|
// SCSIStatus returns the SCSI status byte (byte 3) of a SCSI Response PDU.
|
||||
|
func (p *PDU) SCSIStatus() uint8 { return p.BHS[3] } |
||||
|
|
||||
|
// SetSCSIStatus sets byte 3.
|
||||
|
func (p *PDU) SetSCSIStatus(v uint8) { p.BHS[3] = v } |
||||
@ -0,0 +1,559 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"io" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
// TestPDU drives every PDU codec subtest through one name -> func table,
// so `go test -run TestPDU/<name>` can select individual cases.
func TestPDU(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"roundtrip_bhs_only", testRoundtripBHSOnly},
		{"roundtrip_with_data", testRoundtripWithData},
		{"roundtrip_with_ahs_and_data", testRoundtripWithAHSAndData},
		{"data_padding", testDataPadding},
		{"opcode_accessors", testOpcodeAccessors},
		{"immediate_flag", testImmediateFlag},
		{"field32_accessors", testField32Accessors},
		{"lun_accessor", testLUNAccessor},
		{"isid_accessor", testISIDAccessor},
		{"cdb_accessor", testCDBAccessor},
		{"login_stage_accessors", testLoginStageAccessors},
		{"login_transit_continue", testLoginTransitContinue},
		{"login_status", testLoginStatus},
		{"scsi_response_status", testSCSIResponseStatus},
		{"data_segment_length_3byte", testDataSegmentLength3Byte},
		{"tsih_accessor", testTSIHAccessor},
		{"cmdsn_expstatsn", testCmdSNExpStatSN},
		{"statsn_expcmdsn_maxcmdsn", testStatSNExpCmdSNMaxCmdSN},
		{"datasn_bufferoffset", testDataSNBufferOffset},
		{"r2t_fields", testR2TFields},
		{"residual_count", testResidualCount},
		{"expected_data_transfer_length", testExpectedDataTransferLength},
		{"target_transfer_tag", testTargetTransferTag},
		{"opcode_name", testOpcodeName},
		{"read_truncated_bhs", testReadTruncatedBHS},
		{"read_truncated_data", testReadTruncatedData},
		{"read_truncated_ahs", testReadTruncatedAHS},
		{"read_oversized_data", testReadOversizedData},
		{"read_eof", testReadEOF},
		{"write_invalid_ahs_length", testWriteInvalidAHSLength},
		{"roundtrip_all_opcodes", testRoundtripAllOpcodes},
		{"data_segment_exact_4byte_boundary", testDataSegmentExact4ByteBoundary},
		{"zero_length_data_segment", testZeroLengthDataSegment},
		{"max_3byte_data_length", testMax3ByteDataLength},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
||||
|
|
||||
|
// testRoundtripBHSOnly writes a header-only PDU and verifies the key
// fields survive a Write/Read round trip with no AHS or data.
func testRoundtripBHSOnly(t *testing.T) {
	p := &PDU{}
	p.SetOpcode(OpSCSICmd)
	p.SetInitiatorTaskTag(0xDEADBEEF)
	p.SetCmdSN(42)

	var buf bytes.Buffer
	if err := WritePDU(&buf, p); err != nil {
		t.Fatal(err)
	}
	if buf.Len() != BHSLength {
		t.Fatalf("expected %d bytes, got %d", BHSLength, buf.Len())
	}

	p2, err := ReadPDU(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if p2.Opcode() != OpSCSICmd {
		t.Fatalf("opcode mismatch: got 0x%02x", p2.Opcode())
	}
	if p2.InitiatorTaskTag() != 0xDEADBEEF {
		t.Fatalf("ITT mismatch: got 0x%x", p2.InitiatorTaskTag())
	}
	if p2.CmdSN() != 42 {
		t.Fatalf("CmdSN mismatch: got %d", p2.CmdSN())
	}
}

// testRoundtripWithData verifies a data-carrying PDU: correct padded
// wire length and byte-exact data segment after the round trip.
func testRoundtripWithData(t *testing.T) {
	p := &PDU{}
	p.SetOpcode(OpSCSIDataIn)
	p.DataSegment = []byte("hello, iSCSI world!")

	var buf bytes.Buffer
	if err := WritePDU(&buf, p); err != nil {
		t.Fatal(err)
	}

	// BHS(48) + data(19) + padding(1) = 68
	expectedLen := 48 + pad4(uint32(len(p.DataSegment)))
	if uint32(buf.Len()) != expectedLen {
		t.Fatalf("wire length: expected %d, got %d", expectedLen, buf.Len())
	}

	p2, err := ReadPDU(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(p2.DataSegment, []byte("hello, iSCSI world!")) {
		t.Fatalf("data mismatch: %q", p2.DataSegment)
	}
}

// testRoundtripWithAHSAndData verifies AHS content and data segment
// both survive a round trip when present together.
func testRoundtripWithAHSAndData(t *testing.T) {
	p := &PDU{}
	p.SetOpcode(OpSCSICmd)
	p.AHS = make([]byte, 8) // 2 words
	p.AHS[0] = 0xAA
	p.AHS[7] = 0xBB
	p.DataSegment = []byte{1, 2, 3, 4, 5}

	var buf bytes.Buffer
	if err := WritePDU(&buf, p); err != nil {
		t.Fatal(err)
	}

	p2, err := ReadPDU(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if len(p2.AHS) != 8 {
		t.Fatalf("AHS length: expected 8, got %d", len(p2.AHS))
	}
	if p2.AHS[0] != 0xAA || p2.AHS[7] != 0xBB {
		t.Fatal("AHS content mismatch")
	}
	if !bytes.Equal(p2.DataSegment, []byte{1, 2, 3, 4, 5}) {
		t.Fatal("data segment mismatch")
	}
}

// testDataPadding sweeps data lengths across 4-byte boundaries and
// checks the wire size is always padded while the decoded segment
// keeps its exact original length.
func testDataPadding(t *testing.T) {
	for _, dataLen := range []int{1, 2, 3, 4, 5, 7, 8, 100, 1023} {
		p := &PDU{}
		p.SetOpcode(OpSCSIDataIn)
		p.DataSegment = bytes.Repeat([]byte{0xFF}, dataLen)

		var buf bytes.Buffer
		if err := WritePDU(&buf, p); err != nil {
			t.Fatalf("dataLen=%d: write: %v", dataLen, err)
		}

		expectedWire := BHSLength + int(pad4(uint32(dataLen)))
		if buf.Len() != expectedWire {
			t.Fatalf("dataLen=%d: wire=%d, expected=%d", dataLen, buf.Len(), expectedWire)
		}

		p2, err := ReadPDU(&buf)
		if err != nil {
			t.Fatalf("dataLen=%d: read: %v", dataLen, err)
		}
		if len(p2.DataSegment) != dataLen {
			t.Fatalf("dataLen=%d: got %d", dataLen, len(p2.DataSegment))
		}
	}
}
||||
|
|
||||
|
// testOpcodeAccessors verifies the opcode and immediate-bit accessors
// share byte 0 without clobbering each other.
func testOpcodeAccessors(t *testing.T) {
	p := &PDU{}
	// Set immediate first, then opcode — verify no interference
	p.SetImmediate(true)
	p.SetOpcode(OpLoginReq)
	if p.Opcode() != OpLoginReq {
		t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpLoginReq)
	}
	if !p.Immediate() {
		t.Fatal("immediate flag lost")
	}

	// Change opcode, verify immediate preserved
	p.SetOpcode(OpTextReq)
	if p.Opcode() != OpTextReq {
		t.Fatalf("opcode: got 0x%02x, want 0x%02x", p.Opcode(), OpTextReq)
	}
	if !p.Immediate() {
		t.Fatal("immediate flag lost after opcode change")
	}
}

// testImmediateFlag toggles the immediate bit through its full lifecycle.
func testImmediateFlag(t *testing.T) {
	p := &PDU{}
	if p.Immediate() {
		t.Fatal("should start false")
	}
	p.SetImmediate(true)
	if !p.Immediate() {
		t.Fatal("should be true")
	}
	p.SetImmediate(false)
	if p.Immediate() {
		t.Fatal("should be false again")
	}
}

// testField32Accessors checks the generic 32-bit field get/set pair.
func testField32Accessors(t *testing.T) {
	p := &PDU{}
	p.SetField32(20, 0x12345678)
	if got := p.Field32(20); got != 0x12345678 {
		t.Fatalf("got 0x%08x", got)
	}
}

// testLUNAccessor round-trips a SAM-encoded LUN value.
func testLUNAccessor(t *testing.T) {
	p := &PDU{}
	p.SetLUN(0x0001000000000000) // LUN 1 in SAM encoding
	if p.LUN() != 0x0001000000000000 {
		t.Fatalf("LUN mismatch: 0x%016x", p.LUN())
	}
}

// testISIDAccessor round-trips the 6-byte Initiator Session ID.
func testISIDAccessor(t *testing.T) {
	p := &PDU{}
	isid := [6]byte{0x00, 0x02, 0x3D, 0x00, 0x00, 0x01}
	p.SetISID(isid)
	got := p.ISID()
	if got != isid {
		t.Fatalf("ISID mismatch: %v", got)
	}
}

// testCDBAccessor round-trips a realistic READ(10) CDB.
func testCDBAccessor(t *testing.T) {
	p := &PDU{}
	var cdb [16]byte
	cdb[0] = 0x28 // READ_10
	cdb[1] = 0x00
	binary.BigEndian.PutUint32(cdb[2:6], 100) // LBA
	binary.BigEndian.PutUint16(cdb[7:9], 8) // transfer length
	p.SetCDB(cdb)
	got := p.CDB()
	if got != cdb {
		t.Fatal("CDB mismatch")
	}
}

// testLoginStageAccessors verifies CSG/NSG round-trip across both
// login stage transitions used by the negotiation path.
func testLoginStageAccessors(t *testing.T) {
	p := &PDU{}
	p.SetLoginStages(StageSecurityNeg, StageLoginOp)
	if p.LoginCSG() != StageSecurityNeg {
		t.Fatalf("CSG: got %d", p.LoginCSG())
	}
	if p.LoginNSG() != StageLoginOp {
		t.Fatalf("NSG: got %d", p.LoginNSG())
	}

	// Change to LoginOp -> FullFeature
	p.SetLoginStages(StageLoginOp, StageFullFeature)
	if p.LoginCSG() != StageLoginOp {
		t.Fatalf("CSG: got %d", p.LoginCSG())
	}
	if p.LoginNSG() != StageFullFeature {
		t.Fatalf("NSG: got %d", p.LoginNSG())
	}
}

// testLoginTransitContinue checks T/C flags start cleared and that
// setting stages does not disturb the Transit bit (shared byte 1).
func testLoginTransitContinue(t *testing.T) {
	p := &PDU{}
	if p.LoginTransit() {
		t.Fatal("Transit should start false")
	}
	if p.LoginContinue() {
		t.Fatal("Continue should start false")
	}

	p.SetLoginTransit(true)
	p.SetLoginStages(StageLoginOp, StageFullFeature)
	if !p.LoginTransit() {
		t.Fatal("Transit should be true")
	}
	// Verify stages preserved
	if p.LoginCSG() != StageLoginOp {
		t.Fatal("CSG lost after setting transit")
	}
}

// testLoginStatus round-trips the login response status class/detail pair.
func testLoginStatus(t *testing.T) {
	p := &PDU{}
	p.SetLoginStatus(0x02, 0x01) // Initiator error, authentication failure
	if p.LoginStatusClass() != 0x02 {
		t.Fatalf("class: got %d", p.LoginStatusClass())
	}
	if p.LoginStatusDetail() != 0x01 {
		t.Fatalf("detail: got %d", p.LoginStatusDetail())
	}
}

// testSCSIResponseStatus round-trips the iSCSI response and SCSI status bytes.
func testSCSIResponseStatus(t *testing.T) {
	p := &PDU{}
	p.SetSCSIResponse(ISCSIRespCompleted)
	p.SetSCSIStatus(SCSIStatusGood)
	if p.SCSIResponse() != ISCSIRespCompleted {
		t.Fatal("response mismatch")
	}
	if p.SCSIStatus() != SCSIStatusGood {
		t.Fatal("status mismatch")
	}
}
||||
|
|
||||
|
// testDataSegmentLength3Byte exercises the 24-bit data segment length
// field across byte-boundary values.
func testDataSegmentLength3Byte(t *testing.T) {
	p := &PDU{}
	// Test various sizes including >64KB (needs all 3 bytes)
	for _, size := range []uint32{0, 1, 255, 256, 65535, 65536, 1<<24 - 1} {
		p.SetDataSegmentLength(size)
		if got := p.DataSegmentLength(); got != size {
			t.Fatalf("size %d: got %d", size, got)
		}
	}
}

// testTSIHAccessor round-trips the 16-bit Target Session Identifying Handle.
func testTSIHAccessor(t *testing.T) {
	p := &PDU{}
	p.SetTSIH(0x1234)
	if p.TSIH() != 0x1234 {
		t.Fatalf("TSIH: got 0x%04x", p.TSIH())
	}
}

// testCmdSNExpStatSN round-trips the initiator-side sequence numbers.
func testCmdSNExpStatSN(t *testing.T) {
	p := &PDU{}
	p.SetCmdSN(100)
	p.SetExpStatSN(200)
	if p.CmdSN() != 100 {
		t.Fatal("CmdSN mismatch")
	}
	if p.ExpStatSN() != 200 {
		t.Fatal("ExpStatSN mismatch")
	}
}

// testStatSNExpCmdSNMaxCmdSN round-trips the target-side sequence numbers.
func testStatSNExpCmdSNMaxCmdSN(t *testing.T) {
	p := &PDU{}
	p.SetStatSN(10)
	p.SetExpCmdSN(20)
	p.SetMaxCmdSN(30)
	if p.StatSN() != 10 {
		t.Fatal("StatSN mismatch")
	}
	if p.ExpCmdSN() != 20 {
		t.Fatal("ExpCmdSN mismatch")
	}
	if p.MaxCmdSN() != 30 {
		t.Fatal("MaxCmdSN mismatch")
	}
}

// testDataSNBufferOffset round-trips the Data-In/Data-Out sequencing fields.
func testDataSNBufferOffset(t *testing.T) {
	p := &PDU{}
	p.SetDataSN(5)
	p.SetBufferOffset(8192)
	if p.DataSN() != 5 {
		t.Fatal("DataSN mismatch")
	}
	if p.BufferOffset() != 8192 {
		t.Fatal("BufferOffset mismatch")
	}
}

// testR2TFields round-trips the R2T sequence number and desired length.
func testR2TFields(t *testing.T) {
	p := &PDU{}
	p.SetR2TSN(3)
	p.SetDesiredDataLength(65536)
	if p.R2TSN() != 3 {
		t.Fatal("R2TSN mismatch")
	}
	if p.DesiredDataLength() != 65536 {
		t.Fatal("DesiredDataLength mismatch")
	}
}

// testResidualCount round-trips the SCSI Response residual count.
func testResidualCount(t *testing.T) {
	p := &PDU{}
	p.SetResidualCount(512)
	if p.ResidualCount() != 512 {
		t.Fatal("ResidualCount mismatch")
	}
}

// testExpectedDataTransferLength round-trips the SCSI Command EDTL field.
func testExpectedDataTransferLength(t *testing.T) {
	p := &PDU{}
	p.SetExpectedDataTransferLength(1048576)
	if p.ExpectedDataTransferLength() != 1048576 {
		t.Fatal("mismatch")
	}
}

// testTargetTransferTag round-trips the TTT, including the all-ones
// reserved value.
func testTargetTransferTag(t *testing.T) {
	p := &PDU{}
	p.SetTargetTransferTag(0xFFFFFFFF)
	if p.TargetTransferTag() != 0xFFFFFFFF {
		t.Fatal("mismatch")
	}
}

// testOpcodeName spot-checks known names plus the unknown fallback.
func testOpcodeName(t *testing.T) {
	if OpcodeName(OpSCSICmd) != "SCSI-Command" {
		t.Fatal("wrong name for SCSI-Command")
	}
	if OpcodeName(OpLoginReq) != "Login-Request" {
		t.Fatal("wrong name for Login-Request")
	}
	if !strings.HasPrefix(OpcodeName(0xFF), "Unknown") {
		t.Fatal("unknown opcode should have Unknown prefix")
	}
}
||||
|
|
||||
|
// testReadTruncatedBHS feeds fewer bytes than a full BHS and expects an error.
func testReadTruncatedBHS(t *testing.T) {
	// Only 20 bytes — not enough for BHS
	buf := bytes.NewReader(make([]byte, 20))
	_, err := ReadPDU(buf)
	if err == nil {
		t.Fatal("expected error")
	}
}

// testReadTruncatedData declares a data segment longer than what follows
// the header and expects a truncation error.
func testReadTruncatedData(t *testing.T) {
	// Valid BHS claiming 100 bytes of data, but only 10 bytes follow
	var bhs [BHSLength]byte
	bhs[5] = 0
	bhs[6] = 0
	bhs[7] = 100 // DataSegmentLength = 100
	buf := make([]byte, BHSLength+10)
	copy(buf, bhs[:])

	_, err := ReadPDU(bytes.NewReader(buf))
	if err == nil {
		t.Fatal("expected truncation error")
	}
}

// testReadTruncatedAHS declares more AHS words than are available.
func testReadTruncatedAHS(t *testing.T) {
	// BHS claiming 2 words of AHS but only 4 bytes follow
	var bhs [BHSLength]byte
	bhs[4] = 2 // TotalAHSLength = 2 words = 8 bytes
	buf := make([]byte, BHSLength+4)
	copy(buf, bhs[:])

	_, err := ReadPDU(bytes.NewReader(buf))
	if err == nil {
		t.Fatal("expected truncation error")
	}
}

// testReadOversizedData declares a data segment above the configured cap
// and expects the ErrPDUTooLarge path before any allocation happens.
func testReadOversizedData(t *testing.T) {
	// BHS claiming MaxDataSegmentLength+1 bytes
	var bhs [BHSLength]byte
	oversized := uint32(MaxDataSegmentLength + 1)
	bhs[5] = byte(oversized >> 16)
	bhs[6] = byte(oversized >> 8)
	bhs[7] = byte(oversized)

	_, err := ReadPDU(bytes.NewReader(bhs[:]))
	if err == nil {
		t.Fatal("expected oversized error")
	}
	// NOTE(review): couples the test to ErrPDUTooLarge's message text;
	// errors.Is(err, ErrPDUTooLarge) would be more robust.
	if !strings.Contains(err.Error(), "exceeds maximum") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// testReadEOF verifies a clean EOF (zero bytes) surfaces as io.EOF,
// not as a truncation error.
func testReadEOF(t *testing.T) {
	_, err := ReadPDU(bytes.NewReader(nil))
	if err != io.EOF {
		t.Fatalf("expected io.EOF, got %v", err)
	}
}

// testWriteInvalidAHSLength verifies WritePDU rejects an AHS whose
// length is not a multiple of 4.
func testWriteInvalidAHSLength(t *testing.T) {
	p := &PDU{}
	p.AHS = make([]byte, 5) // not multiple of 4
	var buf bytes.Buffer
	err := WritePDU(&buf, p)
	if err != ErrInvalidAHSLength {
		t.Fatalf("expected ErrInvalidAHSLength, got %v", err)
	}
}
||||
|
|
||||
|
// testRoundtripAllOpcodes round-trips a bare PDU for every defined
// opcode so the codec is exercised across the full opcode space.
func testRoundtripAllOpcodes(t *testing.T) {
	opcodes := []uint8{
		OpNOPOut, OpSCSICmd, OpSCSITaskMgmt, OpLoginReq, OpTextReq,
		OpSCSIDataOut, OpLogoutReq, OpSNACKReq,
		OpNOPIn, OpSCSIResp, OpSCSITaskResp, OpLoginResp, OpTextResp,
		OpSCSIDataIn, OpLogoutResp, OpR2T, OpAsyncMsg, OpReject,
	}
	for _, op := range opcodes {
		p := &PDU{}
		p.SetOpcode(op)
		var buf bytes.Buffer
		if err := WritePDU(&buf, p); err != nil {
			t.Fatalf("op=0x%02x: write: %v", op, err)
		}
		p2, err := ReadPDU(&buf)
		if err != nil {
			t.Fatalf("op=0x%02x: read: %v", op, err)
		}
		if p2.Opcode() != op {
			t.Fatalf("op=0x%02x: got 0x%02x", op, p2.Opcode())
		}
	}
}

// testDataSegmentExact4ByteBoundary checks that already-aligned data
// segments are written with zero padding bytes.
func testDataSegmentExact4ByteBoundary(t *testing.T) {
	for _, size := range []int{4, 8, 12, 16, 256, 4096} {
		p := &PDU{}
		p.DataSegment = bytes.Repeat([]byte{0xAB}, size)
		var buf bytes.Buffer
		if err := WritePDU(&buf, p); err != nil {
			t.Fatalf("size=%d: %v", size, err)
		}
		// No padding needed
		if buf.Len() != BHSLength+size {
			t.Fatalf("size=%d: wire=%d, expected=%d", size, buf.Len(), BHSLength+size)
		}
		p2, err := ReadPDU(&buf)
		if err != nil {
			t.Fatalf("size=%d: read: %v", size, err)
		}
		if len(p2.DataSegment) != size {
			t.Fatalf("size=%d: got %d", size, len(p2.DataSegment))
		}
	}
}

// testZeroLengthDataSegment verifies a nil data segment produces a
// header-only wire image and decodes back to an empty segment.
func testZeroLengthDataSegment(t *testing.T) {
	p := &PDU{}
	p.SetOpcode(OpNOPOut)
	// DataSegment is nil

	var buf bytes.Buffer
	if err := WritePDU(&buf, p); err != nil {
		t.Fatal(err)
	}
	if buf.Len() != BHSLength {
		t.Fatalf("expected %d, got %d", BHSLength, buf.Len())
	}

	p2, err := ReadPDU(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if len(p2.DataSegment) != 0 {
		t.Fatalf("expected empty data segment, got %d bytes", len(p2.DataSegment))
	}
}

// testMax3ByteDataLength stores the maximum encodable 24-bit length
// (header field only — no actual 16MB buffer is allocated).
func testMax3ByteDataLength(t *testing.T) {
	p := &PDU{}
	maxLen := uint32(1<<24 - 1) // 16MB - 1
	p.SetDataSegmentLength(maxLen)
	if got := p.DataSegmentLength(); got != maxLen {
		t.Fatalf("expected %d, got %d", maxLen, got)
	}
}
||||
@ -0,0 +1,463 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
) |
||||
|
|
||||
|
// SCSI opcode constants (SPC-5 / SBC-4)
const (
	ScsiTestUnitReady uint8 = 0x00
	ScsiInquiry uint8 = 0x12
	ScsiModeSense6 uint8 = 0x1a
	ScsiReadCapacity10 uint8 = 0x25
	ScsiRead10 uint8 = 0x28
	ScsiWrite10 uint8 = 0x2a
	ScsiSyncCache10 uint8 = 0x35
	ScsiUnmap uint8 = 0x42
	ScsiReportLuns uint8 = 0xa0
	ScsiRead16 uint8 = 0x88
	ScsiWrite16 uint8 = 0x8a
	ScsiReadCapacity16 uint8 = 0x9e // SERVICE ACTION IN (16), SA=0x10
	ScsiSyncCache16 uint8 = 0x91
)

// Service action for READ CAPACITY (16)
const ScsiSAReadCapacity16 uint8 = 0x10

// SCSI sense keys
const (
	SenseNoSense uint8 = 0x00
	SenseNotReady uint8 = 0x02
	SenseMediumError uint8 = 0x03
	SenseHardwareError uint8 = 0x04
	SenseIllegalRequest uint8 = 0x05
	SenseAbortedCommand uint8 = 0x0b
)

// ASC/ASCQ pairs
const (
	ASCInvalidOpcode uint8 = 0x20 // invalid command operation code
	ASCQLuk uint8 = 0x00 // generic qualifier; NOTE(review): name "Luk" is unclear — confirm intent
	ASCInvalidFieldInCDB uint8 = 0x24
	ASCLBAOutOfRange uint8 = 0x21 // logical block address out of range
	ASCNotReady uint8 = 0x04 // logical unit not ready
	ASCQNotReady uint8 = 0x03 // manual intervention required
)
||||
|
|
||||
|
// BlockDevice is the interface that the SCSI command handler uses to
// interact with the underlying storage. This maps onto BlockVol.
//
// The units of lba/length (bytes vs. logical blocks) are not visible
// from this interface alone — confirm against the BlockVol
// implementation before changing call sites.
type BlockDevice interface {
	// ReadAt reads data starting at the given LBA.
	ReadAt(lba uint64, length uint32) ([]byte, error)
	// WriteAt writes data starting at the given LBA.
	WriteAt(lba uint64, data []byte) error
	// Trim discards (unmaps) a range starting at the given LBA.
	Trim(lba uint64, length uint32) error
	// SyncCache flushes any buffered writes to stable storage.
	SyncCache() error
	// BlockSize returns the logical block size in bytes.
	BlockSize() uint32
	VolumeSize() uint64 // total size in bytes
	// IsHealthy reports whether the device can currently serve I/O
	// (drives TEST UNIT READY).
	IsHealthy() bool
}
||||
|
|
||||
|
// SCSIHandler processes SCSI commands from iSCSI PDUs.
type SCSIHandler struct {
	dev BlockDevice // backing storage for all data commands
	vendorID string // 8 bytes for INQUIRY
	prodID string // 16 bytes for INQUIRY
	serial string // for VPD page 0x80
}
||||
|
|
||||
|
// NewSCSIHandler creates a SCSI command handler for the given block device.
|
||||
|
func NewSCSIHandler(dev BlockDevice) *SCSIHandler { |
||||
|
return &SCSIHandler{ |
||||
|
dev: dev, |
||||
|
vendorID: "SeaweedF", |
||||
|
prodID: "BlockVol ", |
||||
|
serial: "SWF00001", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// SCSIResult holds the result of a SCSI command execution. The Sense*
// fields are only meaningful when Status is CHECK CONDITION.
type SCSIResult struct {
	Status uint8 // SCSI status
	Data []byte // Response data (for Data-In)
	SenseKey uint8 // Sense key (if CHECK_CONDITION)
	SenseASC uint8 // Additional sense code
	SenseASCQ uint8 // Additional sense code qualifier
}
||||
|
|
||||
|
// HandleCommand dispatches a SCSI CDB to the appropriate handler.
|
||||
|
// dataOut contains any data sent by the initiator (for WRITE commands).
|
||||
|
func (h *SCSIHandler) HandleCommand(cdb [16]byte, dataOut []byte) SCSIResult { |
||||
|
opcode := cdb[0] |
||||
|
|
||||
|
switch opcode { |
||||
|
case ScsiTestUnitReady: |
||||
|
return h.testUnitReady() |
||||
|
case ScsiInquiry: |
||||
|
return h.inquiry(cdb) |
||||
|
case ScsiModeSense6: |
||||
|
return h.modeSense6(cdb) |
||||
|
case ScsiReadCapacity10: |
||||
|
return h.readCapacity10() |
||||
|
case ScsiReadCapacity16: |
||||
|
sa := cdb[1] & 0x1f |
||||
|
if sa == ScsiSAReadCapacity16 { |
||||
|
return h.readCapacity16(cdb) |
||||
|
} |
||||
|
return illegalRequest(ASCInvalidOpcode, ASCQLuk) |
||||
|
case ScsiReportLuns: |
||||
|
return h.reportLuns(cdb) |
||||
|
case ScsiRead10: |
||||
|
return h.read10(cdb) |
||||
|
case ScsiRead16: |
||||
|
return h.read16(cdb) |
||||
|
case ScsiWrite10: |
||||
|
return h.write10(cdb, dataOut) |
||||
|
case ScsiWrite16: |
||||
|
return h.write16(cdb, dataOut) |
||||
|
case ScsiSyncCache10: |
||||
|
return h.syncCache() |
||||
|
case ScsiSyncCache16: |
||||
|
return h.syncCache() |
||||
|
case ScsiUnmap: |
||||
|
return h.unmap(cdb, dataOut) |
||||
|
default: |
||||
|
return illegalRequest(ASCInvalidOpcode, ASCQLuk) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// --- Metadata commands ---
|
||||
|
|
||||
|
func (h *SCSIHandler) testUnitReady() SCSIResult { |
||||
|
if !h.dev.IsHealthy() { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseNotReady, |
||||
|
SenseASC: ASCNotReady, |
||||
|
SenseASCQ: ASCQNotReady, |
||||
|
} |
||||
|
} |
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
// inquiry handles the INQUIRY command (SPC-5, 6.6). With EVPD set it
// delegates to the VPD page handler; otherwise it builds the 96-byte
// standard INQUIRY data and truncates to the allocation length.
func (h *SCSIHandler) inquiry(cdb [16]byte) SCSIResult {
	evpd := cdb[1] & 0x01
	pageCode := cdb[2]
	allocLen := binary.BigEndian.Uint16(cdb[3:5])
	if allocLen == 0 {
		// NOTE(review): SPC defines allocation length 0 as "transfer no
		// data"; defaulting to the 36-byte minimum is a deliberate
		// leniency here — confirm intended.
		allocLen = 36
	}

	if evpd != 0 {
		return h.inquiryVPD(pageCode, allocLen)
	}

	// Standard INQUIRY response (SPC-5, Section 6.6.1)
	data := make([]byte, 96)
	data[0] = 0x00 // Peripheral device type: SBC (direct access block device)
	data[1] = 0x00 // RMB=0 (not removable)
	data[2] = 0x06 // SPC-4 version
	data[3] = 0x02 // Response data format = 2 (SPC-2+)
	data[4] = 91 // Additional length (96-5)
	data[5] = 0x00 // SCCS, ACC, TPGS, 3PC
	data[6] = 0x00 // Obsolete, EncServ, VS, MultiP
	data[7] = 0x02 // CmdQue=1 (supports command queuing)

	// Vendor ID (bytes 8-15, 8 chars, space padded)
	copy(data[8:16], padRight(h.vendorID, 8))
	// Product ID (bytes 16-31, 16 chars, space padded)
	copy(data[16:32], padRight(h.prodID, 16))
	// Product revision (bytes 32-35, 4 chars)
	copy(data[32:36], "0001")

	// Truncate to the initiator's allocation length.
	if int(allocLen) < len(data) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
||||
|
|
||||
|
// inquiryVPD builds the Vital Product Data INQUIRY responses. Three pages
// are supported: 0x00 (supported pages list), 0x80 (unit serial number)
// and 0x83 (device identification). Any other page code returns
// CHECK CONDITION / ILLEGAL REQUEST. All pages are truncated to allocLen.
func (h *SCSIHandler) inquiryVPD(pageCode uint8, allocLen uint16) SCSIResult {
	switch pageCode {
	case 0x00: // Supported VPD pages
		data := []byte{
			0x00, // device type
			0x00, // page code
			0x00, 0x03, // page length (3 supported-page entries follow)
			0x00, // supported pages: 0x00
			0x80, // 0x80 (serial)
			0x83, // 0x83 (device identification)
		}
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	case 0x80: // Unit serial number
		serial := padRight(h.serial, 8)
		data := make([]byte, 4+len(serial)) // 4-byte page header + serial
		data[0] = 0x00                      // device type
		data[1] = 0x80                      // page code
		binary.BigEndian.PutUint16(data[2:4], uint16(len(serial)))
		copy(data[4:], serial)
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	case 0x83: // Device identification
		// One designation descriptor: 4-byte descriptor header followed by
		// an 8-byte NAA identifier (12 bytes total).
		naaID := []byte{
			0x01, // code set: binary
			0x03, // identifier type: NAA
			0x00, // reserved
			0x08, // identifier length
			0x60, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // NAA-6 fake
		}
		data := make([]byte, 4+len(naaID)) // 4-byte page header + descriptor
		data[0] = 0x00                     // device type
		data[1] = 0x83                     // page code
		binary.BigEndian.PutUint16(data[2:4], uint16(len(naaID)))
		copy(data[4:], naaID)
		if int(allocLen) < len(data) {
			data = data[:allocLen]
		}
		return SCSIResult{Status: SCSIStatusGood, Data: data}

	default:
		return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk)
	}
}
||||
|
|
||||
|
func (h *SCSIHandler) readCapacity10() SCSIResult { |
||||
|
blockSize := h.dev.BlockSize() |
||||
|
totalBlocks := h.dev.VolumeSize() / uint64(blockSize) |
||||
|
|
||||
|
data := make([]byte, 8) |
||||
|
// If >2TB (blocks > 0xFFFFFFFF), return 0xFFFFFFFF to signal use READ_CAPACITY_16
|
||||
|
if totalBlocks > 0xFFFFFFFF { |
||||
|
binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) |
||||
|
} else { |
||||
|
binary.BigEndian.PutUint32(data[0:4], uint32(totalBlocks-1)) // last LBA
|
||||
|
} |
||||
|
binary.BigEndian.PutUint32(data[4:8], blockSize) |
||||
|
|
||||
|
return SCSIResult{Status: SCSIStatusGood, Data: data} |
||||
|
} |
||||
|
|
||||
|
// readCapacity16 services READ CAPACITY(16): a 32-byte response carrying
// the last LBA, the block length, and the logical block provisioning bits.
//
// NOTE(review): allocLen is clamped UP to 32, so when the initiator asks
// for fewer bytes the response is still 32 bytes; SBC-4 says to truncate
// to the allocation length — confirm this inflation is intentional.
func (h *SCSIHandler) readCapacity16(cdb [16]byte) SCSIResult {
	allocLen := binary.BigEndian.Uint32(cdb[10:14]) // allocation length, CDB bytes 10-13
	if allocLen < 32 {
		allocLen = 32
	}

	blockSize := h.dev.BlockSize()
	totalBlocks := h.dev.VolumeSize() / uint64(blockSize)

	data := make([]byte, 32)
	binary.BigEndian.PutUint64(data[0:8], totalBlocks-1) // last LBA
	binary.BigEndian.PutUint32(data[8:12], blockSize)    // block length
	// Byte 14 bit 7 is LBPME (logical block provisioning management
	// enabled) = 1, advertising UNMAP support.
	data[14] = 0x80 // LBPME bit

	if allocLen < uint32(len(data)) {
		data = data[:allocLen]
	}
	return SCSIResult{Status: SCSIStatusGood, Data: data}
}
||||
|
|
||||
|
func (h *SCSIHandler) modeSense6(cdb [16]byte) SCSIResult { |
||||
|
// Minimal MODE SENSE(6) response — no mode pages
|
||||
|
allocLen := cdb[4] |
||||
|
if allocLen == 0 { |
||||
|
allocLen = 4 |
||||
|
} |
||||
|
|
||||
|
data := make([]byte, 4) |
||||
|
data[0] = 3 // Mode data length (3 bytes follow)
|
||||
|
data[1] = 0x00 // Medium type: default
|
||||
|
data[2] = 0x00 // Device-specific parameter (no write protect)
|
||||
|
data[3] = 0x00 // Block descriptor length = 0
|
||||
|
|
||||
|
if int(allocLen) < len(data) { |
||||
|
data = data[:allocLen] |
||||
|
} |
||||
|
return SCSIResult{Status: SCSIStatusGood, Data: data} |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) reportLuns(cdb [16]byte) SCSIResult { |
||||
|
allocLen := binary.BigEndian.Uint32(cdb[6:10]) |
||||
|
if allocLen < 16 { |
||||
|
allocLen = 16 |
||||
|
} |
||||
|
|
||||
|
// Report a single LUN (LUN 0)
|
||||
|
data := make([]byte, 16) |
||||
|
binary.BigEndian.PutUint32(data[0:4], 8) // LUN list length (8 bytes, 1 LUN)
|
||||
|
// data[4:7] reserved
|
||||
|
// data[8:16] = LUN 0 (all zeros)
|
||||
|
|
||||
|
if allocLen < uint32(len(data)) { |
||||
|
data = data[:allocLen] |
||||
|
} |
||||
|
return SCSIResult{Status: SCSIStatusGood, Data: data} |
||||
|
} |
||||
|
|
||||
|
// --- Data commands (Task 2.6) ---
|
||||
|
|
||||
|
func (h *SCSIHandler) read10(cdb [16]byte) SCSIResult { |
||||
|
lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) |
||||
|
transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) |
||||
|
return h.doRead(lba, transferLen) |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) read16(cdb [16]byte) SCSIResult { |
||||
|
lba := binary.BigEndian.Uint64(cdb[2:10]) |
||||
|
transferLen := binary.BigEndian.Uint32(cdb[10:14]) |
||||
|
return h.doRead(lba, transferLen) |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) write10(cdb [16]byte, dataOut []byte) SCSIResult { |
||||
|
lba := uint64(binary.BigEndian.Uint32(cdb[2:6])) |
||||
|
transferLen := uint32(binary.BigEndian.Uint16(cdb[7:9])) |
||||
|
return h.doWrite(lba, transferLen, dataOut) |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) write16(cdb [16]byte, dataOut []byte) SCSIResult { |
||||
|
lba := binary.BigEndian.Uint64(cdb[2:10]) |
||||
|
transferLen := binary.BigEndian.Uint32(cdb[10:14]) |
||||
|
return h.doWrite(lba, transferLen, dataOut) |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) doRead(lba uint64, transferLen uint32) SCSIResult { |
||||
|
if transferLen == 0 { |
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
blockSize := h.dev.BlockSize() |
||||
|
totalBlocks := h.dev.VolumeSize() / uint64(blockSize) |
||||
|
|
||||
|
if lba+uint64(transferLen) > totalBlocks { |
||||
|
return illegalRequest(ASCLBAOutOfRange, ASCQLuk) |
||||
|
} |
||||
|
|
||||
|
byteLen := transferLen * blockSize |
||||
|
data, err := h.dev.ReadAt(lba, byteLen) |
||||
|
if err != nil { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseMediumError, |
||||
|
SenseASC: 0x11, // Unrecovered read error
|
||||
|
SenseASCQ: 0x00, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return SCSIResult{Status: SCSIStatusGood, Data: data} |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) doWrite(lba uint64, transferLen uint32, dataOut []byte) SCSIResult { |
||||
|
if transferLen == 0 { |
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
blockSize := h.dev.BlockSize() |
||||
|
totalBlocks := h.dev.VolumeSize() / uint64(blockSize) |
||||
|
|
||||
|
if lba+uint64(transferLen) > totalBlocks { |
||||
|
return illegalRequest(ASCLBAOutOfRange, ASCQLuk) |
||||
|
} |
||||
|
|
||||
|
expectedBytes := transferLen * blockSize |
||||
|
if uint32(len(dataOut)) < expectedBytes { |
||||
|
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
||||
|
} |
||||
|
|
||||
|
if err := h.dev.WriteAt(lba, dataOut[:expectedBytes]); err != nil { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseMediumError, |
||||
|
SenseASC: 0x0C, // Write error
|
||||
|
SenseASCQ: 0x00, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) syncCache() SCSIResult { |
||||
|
if err := h.dev.SyncCache(); err != nil { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseHardwareError, |
||||
|
SenseASC: 0x00, |
||||
|
SenseASCQ: 0x00, |
||||
|
} |
||||
|
} |
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
func (h *SCSIHandler) unmap(cdb [16]byte, dataOut []byte) SCSIResult { |
||||
|
if len(dataOut) < 8 { |
||||
|
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
||||
|
} |
||||
|
|
||||
|
// UNMAP parameter list header (8 bytes)
|
||||
|
// descLen := binary.BigEndian.Uint16(dataOut[0:2]) // data length (unused)
|
||||
|
blockDescLen := binary.BigEndian.Uint16(dataOut[2:4]) |
||||
|
|
||||
|
if int(blockDescLen)+8 > len(dataOut) { |
||||
|
return illegalRequest(ASCInvalidFieldInCDB, ASCQLuk) |
||||
|
} |
||||
|
|
||||
|
// Each UNMAP block descriptor is 16 bytes
|
||||
|
descData := dataOut[8 : 8+blockDescLen] |
||||
|
for len(descData) >= 16 { |
||||
|
lba := binary.BigEndian.Uint64(descData[0:8]) |
||||
|
numBlocks := binary.BigEndian.Uint32(descData[8:12]) |
||||
|
// descData[12:16] reserved
|
||||
|
|
||||
|
if numBlocks > 0 { |
||||
|
blockSize := h.dev.BlockSize() |
||||
|
if err := h.dev.Trim(lba, numBlocks*blockSize); err != nil { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseMediumError, |
||||
|
SenseASC: 0x0C, |
||||
|
SenseASCQ: 0x00, |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
descData = descData[16:] |
||||
|
} |
||||
|
|
||||
|
return SCSIResult{Status: SCSIStatusGood} |
||||
|
} |
||||
|
|
||||
|
// BuildSenseData constructs an 18-byte fixed-format sense data buffer
// (response code 0x70, "current errors") carrying the given sense key,
// additional sense code, and qualifier. Only the low nibble of key is
// used, per the fixed-format layout.
func BuildSenseData(key, asc, ascq uint8) []byte {
	var sense [18]byte
	sense[0] = 0x70       // Response code: current errors, fixed format
	sense[2] = key & 0x0f // Sense key occupies the low nibble
	sense[7] = 10         // Additional sense length (bytes 8..17)
	sense[12] = asc       // ASC
	sense[13] = ascq      // ASCQ
	return sense[:]
}
||||
|
|
||||
|
func illegalRequest(asc, ascq uint8) SCSIResult { |
||||
|
return SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: SenseIllegalRequest, |
||||
|
SenseASC: asc, |
||||
|
SenseASCQ: ascq, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// padRight returns s padded with trailing spaces to exactly n bytes, or
// truncated to n bytes when s is longer.
func padRight(s string, n int) string {
	if len(s) >= n {
		return s[:n]
	}
	out := make([]byte, n)
	for i := range out {
		out[i] = ' ' // fill with spaces first...
	}
	copy(out, s) // ...then overlay the string at the front
	return string(out)
}
||||
@ -0,0 +1,692 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
// mockBlockDevice implements BlockDevice for testing.
//
// Storage is a sparse map of whole blocks; blocks never written read back
// as zeros. The *Err fields, when non-nil, force the corresponding method
// to fail so error paths can be exercised.
type mockBlockDevice struct {
	blockSize  uint32            // bytes per logical block (4096 via newMockDevice)
	volumeSize uint64            // total capacity in bytes
	healthy    bool              // value reported by IsHealthy
	blocks     map[uint64][]byte // LBA -> data
	syncErr    error             // injected SyncCache failure
	readErr    error             // injected ReadAt failure
	writeErr   error             // injected WriteAt failure
	trimErr    error             // injected Trim failure
}
||||
|
|
||||
|
func newMockDevice(volumeSize uint64) *mockBlockDevice { |
||||
|
return &mockBlockDevice{ |
||||
|
blockSize: 4096, |
||||
|
volumeSize: volumeSize, |
||||
|
healthy: true, |
||||
|
blocks: make(map[uint64][]byte), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (m *mockBlockDevice) ReadAt(lba uint64, length uint32) ([]byte, error) { |
||||
|
if m.readErr != nil { |
||||
|
return nil, m.readErr |
||||
|
} |
||||
|
blockCount := length / m.blockSize |
||||
|
result := make([]byte, length) |
||||
|
for i := uint32(0); i < blockCount; i++ { |
||||
|
if data, ok := m.blocks[lba+uint64(i)]; ok { |
||||
|
copy(result[i*m.blockSize:], data) |
||||
|
} |
||||
|
// Unwritten blocks return zeros (already zeroed)
|
||||
|
} |
||||
|
return result, nil |
||||
|
} |
||||
|
|
||||
|
func (m *mockBlockDevice) WriteAt(lba uint64, data []byte) error { |
||||
|
if m.writeErr != nil { |
||||
|
return m.writeErr |
||||
|
} |
||||
|
blockCount := uint32(len(data)) / m.blockSize |
||||
|
for i := uint32(0); i < blockCount; i++ { |
||||
|
block := make([]byte, m.blockSize) |
||||
|
copy(block, data[i*m.blockSize:]) |
||||
|
m.blocks[lba+uint64(i)] = block |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (m *mockBlockDevice) Trim(lba uint64, length uint32) error { |
||||
|
if m.trimErr != nil { |
||||
|
return m.trimErr |
||||
|
} |
||||
|
blockCount := length / m.blockSize |
||||
|
for i := uint32(0); i < blockCount; i++ { |
||||
|
delete(m.blocks, lba+uint64(i)) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// SyncCache returns the injected syncErr (nil by default).
func (m *mockBlockDevice) SyncCache() error { return m.syncErr }

// BlockSize reports the logical block size in bytes.
func (m *mockBlockDevice) BlockSize() uint32 { return m.blockSize }

// VolumeSize reports the total capacity in bytes.
func (m *mockBlockDevice) VolumeSize() uint64 { return m.volumeSize }

// IsHealthy reports the configurable health flag.
func (m *mockBlockDevice) IsHealthy() bool { return m.healthy }
||||
|
|
||||
|
// TestSCSI runs the table of SCSI command-handler subtests below; each
// entry exercises one opcode or edge case against a mock block device.
func TestSCSI(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"test_unit_ready_good", testTestUnitReadyGood},
		{"test_unit_ready_not_ready", testTestUnitReadyNotReady},
		{"inquiry_standard", testInquiryStandard},
		{"inquiry_vpd_supported_pages", testInquiryVPDSupportedPages},
		{"inquiry_vpd_serial", testInquiryVPDSerial},
		{"inquiry_vpd_device_id", testInquiryVPDDeviceID},
		{"inquiry_vpd_unknown_page", testInquiryVPDUnknownPage},
		{"inquiry_alloc_length", testInquiryAllocLength},
		{"read_capacity_10", testReadCapacity10},
		{"read_capacity_10_large", testReadCapacity10Large},
		{"read_capacity_16", testReadCapacity16},
		{"read_capacity_16_lbpme", testReadCapacity16LBPME},
		{"mode_sense_6", testModeSense6},
		{"report_luns", testReportLuns},
		{"unknown_opcode", testUnknownOpcode},
		{"read_10", testRead10},
		{"read_16", testRead16},
		{"write_10", testWrite10},
		{"write_16", testWrite16},
		{"read_write_roundtrip", testReadWriteRoundtrip},
		{"write_oob", testWriteOOB},
		{"read_oob", testReadOOB},
		{"zero_length_transfer", testZeroLengthTransfer},
		{"sync_cache", testSyncCache},
		{"sync_cache_error", testSyncCacheError},
		{"unmap_single", testUnmapSingle},
		{"unmap_multiple_descriptors", testUnmapMultipleDescriptors},
		{"unmap_short_param", testUnmapShortParam},
		{"build_sense_data", testBuildSenseData},
		{"read_error", testReadError},
		{"write_error", testWriteError},
		{"read_capacity_16_invalid_sa", testReadCapacity16InvalidSA},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
||||
|
|
||||
|
func testTestUnitReadyGood(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiTestUnitReady |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testTestUnitReadyNotReady(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
dev.healthy = false |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiTestUnitReady |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
if r.SenseKey != SenseNotReady { |
||||
|
t.Fatalf("sense key: %d", r.SenseKey) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryStandard(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 96) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
if len(r.Data) != 96 { |
||||
|
t.Fatalf("data length: %d", len(r.Data)) |
||||
|
} |
||||
|
// Check peripheral device type
|
||||
|
if r.Data[0] != 0x00 { |
||||
|
t.Fatal("not SBC device type") |
||||
|
} |
||||
|
// Check vendor
|
||||
|
vendor := string(r.Data[8:16]) |
||||
|
if vendor != "SeaweedF" { |
||||
|
t.Fatalf("vendor: %q", vendor) |
||||
|
} |
||||
|
// Check CmdQue
|
||||
|
if r.Data[7]&0x02 == 0 { |
||||
|
t.Fatal("CmdQue not set") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryVPDSupportedPages(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
cdb[1] = 0x01 // EVPD
|
||||
|
cdb[2] = 0x00 // Supported pages
|
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 255) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
if r.Data[1] != 0x00 { |
||||
|
t.Fatal("wrong page code") |
||||
|
} |
||||
|
// Should list pages 0x00, 0x80, 0x83
|
||||
|
if len(r.Data) < 7 || r.Data[4] != 0x00 || r.Data[5] != 0x80 || r.Data[6] != 0x83 { |
||||
|
t.Fatalf("supported pages: %v", r.Data) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryVPDSerial(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
cdb[1] = 0x01 |
||||
|
cdb[2] = 0x80 |
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 255) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
if r.Data[1] != 0x80 { |
||||
|
t.Fatal("wrong page code") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryVPDDeviceID(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
cdb[1] = 0x01 |
||||
|
cdb[2] = 0x83 |
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 255) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatalf("status: %d", r.Status) |
||||
|
} |
||||
|
if r.Data[1] != 0x83 { |
||||
|
t.Fatal("wrong page code") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryVPDUnknownPage(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
cdb[1] = 0x01 |
||||
|
cdb[2] = 0xFF // unknown page
|
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 255) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("expected CHECK_CONDITION") |
||||
|
} |
||||
|
if r.SenseKey != SenseIllegalRequest { |
||||
|
t.Fatal("expected ILLEGAL_REQUEST") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testInquiryAllocLength(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiInquiry |
||||
|
binary.BigEndian.PutUint16(cdb[3:5], 10) // small alloc
|
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("should succeed") |
||||
|
} |
||||
|
if len(r.Data) != 10 { |
||||
|
t.Fatalf("truncation: expected 10, got %d", len(r.Data)) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadCapacity10(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) // 100 blocks
|
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReadCapacity10 |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("status not good") |
||||
|
} |
||||
|
if len(r.Data) != 8 { |
||||
|
t.Fatalf("data length: %d", len(r.Data)) |
||||
|
} |
||||
|
lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) |
||||
|
blockSize := binary.BigEndian.Uint32(r.Data[4:8]) |
||||
|
if lastLBA != 99 { |
||||
|
t.Fatalf("last LBA: %d, expected 99", lastLBA) |
||||
|
} |
||||
|
if blockSize != 4096 { |
||||
|
t.Fatalf("block size: %d", blockSize) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadCapacity10Large(t *testing.T) { |
||||
|
// Volume with >2^32 blocks should return 0xFFFFFFFF
|
||||
|
dev := newMockDevice(uint64(0x100000001) * 4096) // 2^32+1 blocks
|
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReadCapacity10 |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
lastLBA := binary.BigEndian.Uint32(r.Data[0:4]) |
||||
|
if lastLBA != 0xFFFFFFFF { |
||||
|
t.Fatalf("should return 0xFFFFFFFF for >2TB, got %d", lastLBA) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadCapacity16(t *testing.T) { |
||||
|
dev := newMockDevice(3 * 1024 * 1024 * 1024 * 1024) // 3 TB
|
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReadCapacity16 |
||||
|
cdb[1] = ScsiSAReadCapacity16 |
||||
|
binary.BigEndian.PutUint32(cdb[10:14], 32) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("status not good") |
||||
|
} |
||||
|
lastLBA := binary.BigEndian.Uint64(r.Data[0:8]) |
||||
|
expectedBlocks := uint64(3*1024*1024*1024*1024) / 4096 |
||||
|
if lastLBA != expectedBlocks-1 { |
||||
|
t.Fatalf("last LBA: %d, expected %d", lastLBA, expectedBlocks-1) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadCapacity16LBPME(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReadCapacity16 |
||||
|
cdb[1] = ScsiSAReadCapacity16 |
||||
|
binary.BigEndian.PutUint32(cdb[10:14], 32) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
// LBPME bit should be set (byte 14, bit 7)
|
||||
|
if r.Data[14]&0x80 == 0 { |
||||
|
t.Fatal("LBPME bit not set") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testModeSense6(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiModeSense6 |
||||
|
cdb[4] = 255 |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("status not good") |
||||
|
} |
||||
|
if len(r.Data) != 4 { |
||||
|
t.Fatalf("mode sense data: %d bytes", len(r.Data)) |
||||
|
} |
||||
|
// No write protect
|
||||
|
if r.Data[2]&0x80 != 0 { |
||||
|
t.Fatal("write protect set") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReportLuns(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReportLuns |
||||
|
binary.BigEndian.PutUint32(cdb[6:10], 256) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("status not good") |
||||
|
} |
||||
|
lunListLen := binary.BigEndian.Uint32(r.Data[0:4]) |
||||
|
if lunListLen != 8 { |
||||
|
t.Fatalf("LUN list length: %d (expected 8 for 1 LUN)", lunListLen) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testUnknownOpcode(t *testing.T) { |
||||
|
dev := newMockDevice(1024 * 1024) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = 0xFF |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("expected CHECK_CONDITION") |
||||
|
} |
||||
|
if r.SenseKey != SenseIllegalRequest { |
||||
|
t.Fatal("expected ILLEGAL_REQUEST") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testRead10(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
// Write some data first
|
||||
|
data := make([]byte, 4096) |
||||
|
for i := range data { |
||||
|
data[i] = 0xAB |
||||
|
} |
||||
|
dev.blocks[5] = data |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 5) // LBA=5
|
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
|
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("read failed") |
||||
|
} |
||||
|
if len(r.Data) != 4096 { |
||||
|
t.Fatalf("data length: %d", len(r.Data)) |
||||
|
} |
||||
|
if r.Data[0] != 0xAB { |
||||
|
t.Fatal("data mismatch") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testRead16(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
data := make([]byte, 4096) |
||||
|
data[0] = 0xCD |
||||
|
dev.blocks[10] = data |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiRead16 |
||||
|
binary.BigEndian.PutUint64(cdb[2:10], 10) |
||||
|
binary.BigEndian.PutUint32(cdb[10:14], 1) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("read16 failed") |
||||
|
} |
||||
|
if r.Data[0] != 0xCD { |
||||
|
t.Fatal("data mismatch") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWrite10(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
dataOut := make([]byte, 4096) |
||||
|
dataOut[0] = 0xEF |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 7) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) |
||||
|
r := h.HandleCommand(cdb, dataOut) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("write failed") |
||||
|
} |
||||
|
if dev.blocks[7][0] != 0xEF { |
||||
|
t.Fatal("data not written") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWrite16(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
dataOut := make([]byte, 8192) |
||||
|
dataOut[0] = 0x11 |
||||
|
dataOut[4096] = 0x22 |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiWrite16 |
||||
|
binary.BigEndian.PutUint64(cdb[2:10], 50) |
||||
|
binary.BigEndian.PutUint32(cdb[10:14], 2) |
||||
|
r := h.HandleCommand(cdb, dataOut) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("write16 failed") |
||||
|
} |
||||
|
if dev.blocks[50][0] != 0x11 { |
||||
|
t.Fatal("block 50 wrong") |
||||
|
} |
||||
|
if dev.blocks[51][0] != 0x22 { |
||||
|
t.Fatal("block 51 wrong") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadWriteRoundtrip(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
// Write
|
||||
|
dataOut := make([]byte, 4096) |
||||
|
for i := range dataOut { |
||||
|
dataOut[i] = byte(i % 256) |
||||
|
} |
||||
|
var wcdb [16]byte |
||||
|
wcdb[0] = ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(wcdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(wcdb[7:9], 1) |
||||
|
h.HandleCommand(wcdb, dataOut) |
||||
|
|
||||
|
// Read back
|
||||
|
var rcdb [16]byte |
||||
|
rcdb[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(rcdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(rcdb[7:9], 1) |
||||
|
r := h.HandleCommand(rcdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("read failed") |
||||
|
} |
||||
|
for i := 0; i < 4096; i++ { |
||||
|
if r.Data[i] != byte(i%256) { |
||||
|
t.Fatalf("byte %d: got %d, want %d", i, r.Data[i], i%256) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWriteOOB(t *testing.T) { |
||||
|
dev := newMockDevice(10 * 4096) // 10 blocks
|
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 9) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 2) // LBA 9 + 2 blocks > 10
|
||||
|
r := h.HandleCommand(cdb, make([]byte, 8192)) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail for OOB") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadOOB(t *testing.T) { |
||||
|
dev := newMockDevice(10 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 10) // LBA 10 == total blocks
|
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail for OOB") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testZeroLengthTransfer(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 0) // 0 blocks
|
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("zero-length read should succeed") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSyncCache(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiSyncCache10 |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("sync cache failed") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSyncCacheError(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
dev.syncErr = errors.New("disk error") |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiSyncCache10 |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testUnmapSingle(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
// Write data at LBA 5
|
||||
|
dev.blocks[5] = make([]byte, 4096) |
||||
|
dev.blocks[5][0] = 0xFF |
||||
|
|
||||
|
// UNMAP parameter list
|
||||
|
unmapData := make([]byte, 24) // 8 header + 16 descriptor
|
||||
|
binary.BigEndian.PutUint16(unmapData[0:2], 22) // data length
|
||||
|
binary.BigEndian.PutUint16(unmapData[2:4], 16) // block desc length
|
||||
|
binary.BigEndian.PutUint64(unmapData[8:16], 5) // LBA
|
||||
|
binary.BigEndian.PutUint32(unmapData[16:20], 1) // num blocks
|
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiUnmap |
||||
|
r := h.HandleCommand(cdb, unmapData) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("unmap failed") |
||||
|
} |
||||
|
if _, ok := dev.blocks[5]; ok { |
||||
|
t.Fatal("block 5 should be trimmed") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testUnmapMultipleDescriptors(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
dev.blocks[3] = make([]byte, 4096) |
||||
|
dev.blocks[7] = make([]byte, 4096) |
||||
|
|
||||
|
// 2 descriptors
|
||||
|
unmapData := make([]byte, 40) // 8 header + 2*16 descriptors
|
||||
|
binary.BigEndian.PutUint16(unmapData[0:2], 38) |
||||
|
binary.BigEndian.PutUint16(unmapData[2:4], 32) |
||||
|
// Descriptor 1: LBA=3, count=1
|
||||
|
binary.BigEndian.PutUint64(unmapData[8:16], 3) |
||||
|
binary.BigEndian.PutUint32(unmapData[16:20], 1) |
||||
|
// Descriptor 2: LBA=7, count=1
|
||||
|
binary.BigEndian.PutUint64(unmapData[24:32], 7) |
||||
|
binary.BigEndian.PutUint32(unmapData[32:36], 1) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiUnmap |
||||
|
r := h.HandleCommand(cdb, unmapData) |
||||
|
if r.Status != SCSIStatusGood { |
||||
|
t.Fatal("unmap failed") |
||||
|
} |
||||
|
if _, ok := dev.blocks[3]; ok { |
||||
|
t.Fatal("block 3 should be trimmed") |
||||
|
} |
||||
|
if _, ok := dev.blocks[7]; ok { |
||||
|
t.Fatal("block 7 should be trimmed") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testUnmapShortParam(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiUnmap |
||||
|
r := h.HandleCommand(cdb, []byte{1, 2, 3}) // too short
|
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail for short unmap params") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testBuildSenseData(t *testing.T) { |
||||
|
data := BuildSenseData(SenseIllegalRequest, ASCInvalidOpcode, ASCQLuk) |
||||
|
if len(data) != 18 { |
||||
|
t.Fatalf("length: %d", len(data)) |
||||
|
} |
||||
|
if data[0] != 0x70 { |
||||
|
t.Fatal("response code wrong") |
||||
|
} |
||||
|
if data[2] != SenseIllegalRequest { |
||||
|
t.Fatal("sense key wrong") |
||||
|
} |
||||
|
if data[12] != ASCInvalidOpcode { |
||||
|
t.Fatal("ASC wrong") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadError(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
dev.readErr = errors.New("io error") |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) |
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail") |
||||
|
} |
||||
|
if r.SenseKey != SenseMediumError { |
||||
|
t.Fatal("should be MEDIUM_ERROR") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWriteError(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
dev.writeErr = errors.New("io error") |
||||
|
h := NewSCSIHandler(dev) |
||||
|
|
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) |
||||
|
r := h.HandleCommand(cdb, make([]byte, 4096)) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testReadCapacity16InvalidSA(t *testing.T) { |
||||
|
dev := newMockDevice(100 * 4096) |
||||
|
h := NewSCSIHandler(dev) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiReadCapacity16 |
||||
|
cdb[1] = 0x05 // wrong service action
|
||||
|
r := h.HandleCommand(cdb, nil) |
||||
|
if r.Status != SCSIStatusCheckCond { |
||||
|
t.Fatal("should fail for wrong SA") |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,421 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"log" |
||||
|
"net" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
) |
||||
|
|
||||
|
// Sentinel errors for session-level failures.
// NOTE(review): neither sentinel is referenced anywhere in this file's
// visible code — confirm they are used by callers or remove them.
var (
	ErrSessionClosed    = errors.New("iscsi: session closed")
	ErrCmdSNOutOfWindow = errors.New("iscsi: CmdSN out of window")
)

// SessionState tracks the lifecycle of an iSCSI session.
type SessionState int

const (
	SessionLogin    SessionState = iota // login phase
	SessionLoggedIn                     // full feature phase
	SessionLogout                       // logout requested
	SessionClosed                       // terminated
)

// Session manages a single iSCSI session (one initiator connection).
// All PDU handling runs on the HandleConnection goroutine; mu serializes
// state/statSN mutation against concurrent Close callers.
type Session struct {
	mu sync.Mutex

	state    SessionState
	conn     net.Conn
	scsi     *SCSIHandler // bound after login; nil until then
	config   TargetConfig
	resolver TargetResolver
	devices  DeviceLookup

	// Sequence numbers
	expCmdSN atomic.Uint32 // expected CmdSN from initiator
	maxCmdSN atomic.Uint32 // max CmdSN we allow
	statSN   uint32        // target status sequence number; guarded by mu

	// Login state
	negotiator *LoginNegotiator
	loginDone  bool

	// Data sequencing
	dataInWriter *DataInWriter // created once login negotiation finishes

	// Shutdown
	closed   atomic.Bool
	closeErr error

	// Logging
	logger *log.Logger
}
||||
|
|
||||
|
// NewSession creates a new iSCSI session on the given connection.
|
||||
|
func NewSession(conn net.Conn, config TargetConfig, resolver TargetResolver, devices DeviceLookup, logger *log.Logger) *Session { |
||||
|
if logger == nil { |
||||
|
logger = log.Default() |
||||
|
} |
||||
|
s := &Session{ |
||||
|
state: SessionLogin, |
||||
|
conn: conn, |
||||
|
config: config, |
||||
|
resolver: resolver, |
||||
|
devices: devices, |
||||
|
negotiator: NewLoginNegotiator(config), |
||||
|
logger: logger, |
||||
|
} |
||||
|
s.expCmdSN.Store(1) |
||||
|
s.maxCmdSN.Store(32) // window of 32 commands
|
||||
|
return s |
||||
|
} |
||||
|
|
||||
|
// HandleConnection processes PDUs until the connection is closed or an error occurs.
|
||||
|
func (s *Session) HandleConnection() error { |
||||
|
defer s.close() |
||||
|
|
||||
|
for !s.closed.Load() { |
||||
|
pdu, err := ReadPDU(s.conn) |
||||
|
if err != nil { |
||||
|
if s.closed.Load() { |
||||
|
return nil |
||||
|
} |
||||
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { |
||||
|
return nil |
||||
|
} |
||||
|
return fmt.Errorf("read PDU: %w", err) |
||||
|
} |
||||
|
|
||||
|
if err := s.dispatch(pdu); err != nil { |
||||
|
if s.closed.Load() { |
||||
|
return nil |
||||
|
} |
||||
|
return fmt.Errorf("dispatch %s: %w", OpcodeName(pdu.Opcode()), err) |
||||
|
} |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Close terminates the session. Setting closed before closing the socket
// lets HandleConnection treat the resulting read error as a clean shutdown.
func (s *Session) Close() error {
	s.closed.Store(true)
	return s.conn.Close()
}
||||
|
|
||||
|
// close marks the session terminated and releases the connection.
// Called from the HandleConnection defer; safe to call more than once.
func (s *Session) close() {
	s.closed.Store(true)
	s.mu.Lock()
	s.state = SessionClosed
	s.mu.Unlock()
	s.conn.Close()
}
||||
|
|
||||
|
func (s *Session) dispatch(pdu *PDU) error { |
||||
|
op := pdu.Opcode() |
||||
|
|
||||
|
switch op { |
||||
|
case OpLoginReq: |
||||
|
return s.handleLogin(pdu) |
||||
|
case OpTextReq: |
||||
|
return s.handleText(pdu) |
||||
|
case OpSCSICmd: |
||||
|
return s.handleSCSICmd(pdu) |
||||
|
case OpSCSIDataOut: |
||||
|
// Handled inline during write command processing
|
||||
|
return nil |
||||
|
case OpNOPOut: |
||||
|
return s.handleNOPOut(pdu) |
||||
|
case OpLogoutReq: |
||||
|
return s.handleLogout(pdu) |
||||
|
case OpSCSITaskMgmt: |
||||
|
return s.handleTaskMgmt(pdu) |
||||
|
default: |
||||
|
s.logger.Printf("unhandled opcode: %s", OpcodeName(op)) |
||||
|
return s.sendReject(pdu, 0x04) // command not supported
|
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// handleLogin feeds one Login request into the negotiator, sends the
// negotiated response, and — once negotiation completes — transitions the
// session to full-feature phase and binds the SCSI handler to the device
// for the target the initiator logged into.
func (s *Session) handleLogin(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	resp := s.negotiator.HandleLoginPDU(pdu, s.resolver)

	// Set sequence numbers
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}

	if s.negotiator.Done() {
		s.loginDone = true
		s.state = SessionLoggedIn
		result := s.negotiator.Result()
		s.dataInWriter = NewDataInWriter(uint32(result.MaxRecvDataSegLen))

		// Bind SCSI handler to the device for the target the initiator logged into.
		// NOTE(review): when devices is nil or TargetName is empty (e.g. a
		// discovery session), s.scsi stays nil — SCSI command handling must
		// guard against that.
		if s.devices != nil && result.TargetName != "" {
			dev := s.devices.LookupDevice(result.TargetName)
			s.scsi = NewSCSIHandler(dev)
		}

		s.logger.Printf("login complete: initiator=%s target=%s session=%s",
			result.InitiatorName, result.TargetName, result.SessionType)
	}

	return nil
}
||||
|
|
||||
|
// handleText answers a Text request (typically SendTargets discovery) with
// a Text response built from the resolver's target list, if available.
func (s *Session) handleText(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Gather discovery targets from resolver if it supports listing;
	// otherwise the response is built with an empty target list.
	var targets []DiscoveryTarget
	if lister, ok := s.resolver.(TargetLister); ok {
		targets = lister.ListTargets()
	}

	resp := HandleTextRequest(pdu, targets)
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	return WritePDU(s.conn, resp)
}
||||
|
|
||||
|
func (s *Session) handleSCSICmd(pdu *PDU) error { |
||||
|
if !s.loginDone { |
||||
|
return s.sendReject(pdu, 0x0b) // protocol error
|
||||
|
} |
||||
|
|
||||
|
cdb := pdu.CDB() |
||||
|
itt := pdu.InitiatorTaskTag() |
||||
|
flags := pdu.OpSpecific1() |
||||
|
|
||||
|
// Advance CmdSN
|
||||
|
if !pdu.Immediate() { |
||||
|
s.advanceCmdSN() |
||||
|
} |
||||
|
|
||||
|
isWrite := flags&FlagW != 0 |
||||
|
isRead := flags&FlagR != 0 |
||||
|
expectedLen := pdu.ExpectedDataTransferLength() |
||||
|
|
||||
|
// Handle write commands — collect data
|
||||
|
var dataOut []byte |
||||
|
if isWrite && expectedLen > 0 { |
||||
|
collector := NewDataOutCollector(expectedLen) |
||||
|
|
||||
|
// Immediate data
|
||||
|
if len(pdu.DataSegment) > 0 { |
||||
|
if err := collector.AddImmediateData(pdu.DataSegment); err != nil { |
||||
|
return s.sendCheckCondition(itt, SenseIllegalRequest, ASCInvalidFieldInCDB, ASCQLuk) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// If more data needed, send R2T and collect Data-Out PDUs
|
||||
|
if !collector.Done() { |
||||
|
if err := s.collectDataOut(collector, itt); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
dataOut = collector.Data() |
||||
|
} |
||||
|
|
||||
|
// Execute SCSI command
|
||||
|
result := s.scsi.HandleCommand(cdb, dataOut) |
||||
|
|
||||
|
// Send response
|
||||
|
s.mu.Lock() |
||||
|
expCmdSN := s.expCmdSN.Load() |
||||
|
maxCmdSN := s.maxCmdSN.Load() |
||||
|
s.mu.Unlock() |
||||
|
|
||||
|
if isRead && result.Status == SCSIStatusGood && len(result.Data) > 0 { |
||||
|
// Send Data-In PDUs
|
||||
|
s.mu.Lock() |
||||
|
_, err := s.dataInWriter.WriteDataIn(s.conn, result.Data, itt, expCmdSN, maxCmdSN, &s.statSN) |
||||
|
s.mu.Unlock() |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// Send SCSI Response
|
||||
|
s.mu.Lock() |
||||
|
err := SendSCSIResponse(s.conn, result, itt, &s.statSN, expCmdSN, maxCmdSN) |
||||
|
s.mu.Unlock() |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// collectDataOut solicits the remainder of a write payload: it sends an R2T
// for the outstanding range, then reads Data-Out PDUs into the collector
// until the F-bit, repeating until the collector is satisfied.
// Runs on the HandleConnection goroutine, so reading from s.conn here is
// safe (no concurrent reader).
func (s *Session) collectDataOut(collector *DataOutCollector, itt uint32) error {
	var r2tSN uint32
	ttt := itt // use ITT as TTT for simplicity

	for !collector.Done() {
		// Send R2T covering the bytes still missing.
		s.mu.Lock()
		r2t := BuildR2T(itt, ttt, r2tSN, s.totalReceived(collector), collector.Remaining(),
			s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load())
		s.mu.Unlock()

		if err := WritePDU(s.conn, r2t); err != nil {
			return err
		}
		r2tSN++

		// Read Data-Out PDUs until F-bit marks the end of this sequence.
		for {
			doPDU, err := ReadPDU(s.conn)
			if err != nil {
				return err
			}
			if doPDU.Opcode() != OpSCSIDataOut {
				return fmt.Errorf("expected Data-Out, got %s", OpcodeName(doPDU.Opcode()))
			}
			if err := collector.AddDataOut(doPDU); err != nil {
				return err
			}
			if doPDU.OpSpecific1()&FlagF != 0 {
				break
			}
		}
	}
	return nil
}
||||
|
|
||||
|
func (s *Session) totalReceived(c *DataOutCollector) uint32 { |
||||
|
return c.expectedLen - c.Remaining() |
||||
|
} |
||||
|
|
||||
|
func (s *Session) handleNOPOut(pdu *PDU) error { |
||||
|
resp := &PDU{} |
||||
|
resp.SetOpcode(OpNOPIn) |
||||
|
resp.SetOpSpecific1(FlagF) |
||||
|
resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag()) |
||||
|
resp.SetTargetTransferTag(0xFFFFFFFF) |
||||
|
|
||||
|
s.mu.Lock() |
||||
|
resp.SetStatSN(s.statSN) |
||||
|
s.statSN++ |
||||
|
resp.SetExpCmdSN(s.expCmdSN.Load()) |
||||
|
resp.SetMaxCmdSN(s.maxCmdSN.Load()) |
||||
|
s.mu.Unlock() |
||||
|
|
||||
|
// Echo back data if present
|
||||
|
if len(pdu.DataSegment) > 0 { |
||||
|
resp.DataSegment = pdu.DataSegment |
||||
|
} |
||||
|
|
||||
|
return WritePDU(s.conn, resp) |
||||
|
} |
||||
|
|
||||
|
// handleLogout acknowledges a Logout request with a successful Logout
// response, then tears down the connection so HandleConnection exits.
// NOTE(review): a non-immediate Logout should consume a CmdSN slot like
// other commands; this handler never advances the window — harmless since
// the session ends here, but confirm intended.
func (s *Session) handleLogout(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.state = SessionLogout

	resp := &PDU{}
	resp.SetOpcode(OpLogoutResp)
	resp.SetOpSpecific1(FlagF)
	resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	resp.BHS[2] = 0x00 // response: connection/session closed successfully
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	if err := WritePDU(s.conn, resp); err != nil {
		return err
	}

	// Signal HandleConnection to exit
	s.closed.Store(true)
	s.conn.Close()
	return nil
}
||||
|
|
||||
|
// handleTaskMgmt acknowledges a Task Management request.
// Simplified: every function (abort, reset, ...) is answered with
// "function complete" without actually acting on outstanding tasks —
// acceptable here because commands are processed synchronously one at a
// time on this connection.
func (s *Session) handleTaskMgmt(pdu *PDU) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Simplified: always respond with "function complete"
	resp := &PDU{}
	resp.SetOpcode(OpSCSITaskResp)
	resp.SetOpSpecific1(FlagF)
	resp.SetInitiatorTaskTag(pdu.InitiatorTaskTag())
	resp.BHS[2] = 0x00 // function complete
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())

	return WritePDU(s.conn, resp)
}
||||
|
|
||||
|
// advanceCmdSN slides the command window forward by one: the next expected
// CmdSN and the window's upper bound both increase, keeping the window
// width constant.
func (s *Session) advanceCmdSN() {
	s.expCmdSN.Add(1)
	s.maxCmdSN.Add(1)
}
||||
|
|
||||
|
// sendReject sends a Reject PDU for origPDU with the given RFC 7143 reason
// code, echoing the offending BHS in the data segment as required.
func (s *Session) sendReject(origPDU *PDU, reason uint8) error {
	resp := &PDU{}
	resp.SetOpcode(OpReject)
	resp.SetOpSpecific1(FlagF)
	resp.BHS[2] = reason
	resp.SetInitiatorTaskTag(0xFFFFFFFF)

	s.mu.Lock()
	resp.SetStatSN(s.statSN)
	s.statSN++
	resp.SetExpCmdSN(s.expCmdSN.Load())
	resp.SetMaxCmdSN(s.maxCmdSN.Load())
	s.mu.Unlock()

	// Include the rejected BHS in the data segment.
	// Note: this aliases origPDU's BHS array rather than copying it; safe
	// as long as origPDU is not mutated before the write completes.
	resp.DataSegment = origPDU.BHS[:]

	return WritePDU(s.conn, resp)
}
||||
|
|
||||
|
func (s *Session) sendCheckCondition(itt uint32, senseKey, asc, ascq uint8) error { |
||||
|
result := SCSIResult{ |
||||
|
Status: SCSIStatusCheckCond, |
||||
|
SenseKey: senseKey, |
||||
|
SenseASC: asc, |
||||
|
SenseASCQ: ascq, |
||||
|
} |
||||
|
s.mu.Lock() |
||||
|
defer s.mu.Unlock() |
||||
|
return SendSCSIResponse(s.conn, result, itt, &s.statSN, s.expCmdSN.Load(), s.maxCmdSN.Load()) |
||||
|
} |
||||
|
|
||||
|
// TargetLister is an optional interface that TargetResolver can implement
// to support SendTargets discovery. handleText type-asserts for it; a
// resolver without it simply advertises no targets.
type TargetLister interface {
	ListTargets() []DiscoveryTarget
}
||||
|
|
||||
|
// DeviceLookup resolves a target IQN to a BlockDevice.
// Used after login to bind the session to the correct volume.
type DeviceLookup interface {
	LookupDevice(iqn string) BlockDevice
}
||||
|
|
||||
|
// State returns the current session state under the session lock.
func (s *Session) State() SessionState {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.state
}
||||
@ -0,0 +1,364 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"io" |
||||
|
"log" |
||||
|
"net" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// testResolver implements TargetResolver, TargetLister, and DeviceLookup.
type testResolver struct {
	targets []DiscoveryTarget // advertised discovery targets
	dev     BlockDevice       // device returned by LookupDevice; nil -> nullDevice
}
||||
|
|
||||
|
func (r *testResolver) HasTarget(name string) bool { |
||||
|
for _, t := range r.targets { |
||||
|
if t.Name == name { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
// ListTargets implements TargetLister for SendTargets discovery tests.
func (r *testResolver) ListTargets() []DiscoveryTarget {
	return r.targets
}
||||
|
|
||||
|
func (r *testResolver) LookupDevice(iqn string) BlockDevice { |
||||
|
if r.dev != nil { |
||||
|
return r.dev |
||||
|
} |
||||
|
return &nullDevice{} |
||||
|
} |
||||
|
|
||||
|
// newTestResolver builds a resolver with the default test target and a
// nullDevice backing store.
func newTestResolver() *testResolver {
	return newTestResolverWithDevice(nil)
}
||||
|
|
||||
|
// newTestResolverWithDevice builds a resolver advertising testTargetName,
// backed by dev (nil means LookupDevice will return a nullDevice).
func newTestResolverWithDevice(dev BlockDevice) *testResolver {
	return &testResolver{
		targets: []DiscoveryTarget{
			{Name: testTargetName, Address: "127.0.0.1:3260,1"},
		},
		dev: dev,
	}
}
||||
|
|
||||
|
// TestSession is the table-driven entry point for all session-level tests.
func TestSession(t *testing.T) {
	tests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"login_and_read", testLoginAndRead},
		{"login_and_write", testLoginAndWrite},
		{"nop_ping", testNOPPing},
		{"logout", testLogout},
		{"discovery_session", testDiscoverySession},
		{"task_mgmt", testTaskMgmt},
		{"reject_scsi_before_login", testRejectSCSIBeforeLogin},
		{"connection_close", testConnectionClose},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.run(t)
		})
	}
}
||||
|
|
||||
|
// sessionTestEnv bundles the client half of an in-memory connection, the
// session under test, and a channel carrying HandleConnection's result.
type sessionTestEnv struct {
	clientConn net.Conn
	session    *Session
	done       chan error
}
||||
|
|
||||
|
// setupSession wires a Session to one end of a net.Pipe, starts
// HandleConnection on a goroutine, and registers cleanup. The returned env
// gives the test the client end of the pipe.
func setupSession(t *testing.T) *sessionTestEnv {
	t.Helper()
	client, server := net.Pipe()

	dev := newMockDevice(1024 * 4096) // 1024 blocks
	config := DefaultTargetConfig()
	config.TargetName = testTargetName
	resolver := newTestResolverWithDevice(dev)
	logger := log.New(io.Discard, "", 0)

	sess := NewSession(server, config, resolver, resolver, logger)
	done := make(chan error, 1)
	go func() {
		done <- sess.HandleConnection()
	}()

	t.Cleanup(func() {
		sess.Close()
		client.Close()
	})

	return &sessionTestEnv{clientConn: client, session: sess, done: done}
}
||||
|
|
||||
|
// doLogin performs a single-PDU Normal-session login as the test initiator
// and fails the test if the target does not answer with a successful
// Login response.
func doLogin(t *testing.T, conn net.Conn) {
	t.Helper()
	params := NewParams()
	params.Set("InitiatorName", testInitiatorName)
	params.Set("TargetName", testTargetName)
	params.Set("SessionType", "Normal")

	req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params)
	req.SetCmdSN(1)

	if err := WritePDU(conn, req); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(conn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLoginResp {
		t.Fatalf("expected LoginResp, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.LoginStatusClass() != LoginStatusSuccess {
		t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail())
	}
}
||||
|
|
||||
|
// testLoginAndRead logs in, issues READ_10 for one 4 KiB block at LBA 0,
// and expects a single Data-In PDU with the S-bit set and a full block of
// data (the mock device reads back zeros for untouched blocks).
func testLoginAndRead(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send SCSI READ_10 for 1 block at LBA 0
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR) // Final + Read
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)
	cmd.SetExpStatSN(2)

	var cdb [16]byte
	cdb[0] = ScsiRead10
	binary.BigEndian.PutUint32(cdb[2:6], 0) // LBA=0
	binary.BigEndian.PutUint16(cdb[7:9], 1) // 1 block
	cmd.SetCDB(cdb)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	// Expect Data-In with S-bit (single PDU, 4096 bytes of zeros)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIDataIn {
		t.Fatalf("expected Data-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.OpSpecific1()&FlagS == 0 {
		t.Fatal("S-bit not set")
	}
	if len(resp.DataSegment) != 4096 {
		t.Fatalf("data length: %d", len(resp.DataSegment))
	}
}
||||
|
|
||||
|
// testLoginAndWrite logs in and issues WRITE_10 with the full payload as
// immediate data (no R2T round trip needed), expecting a GOOD SCSI
// Response.
func testLoginAndWrite(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// SCSI WRITE_10: 1 block at LBA 5 with immediate data
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagW) // Final + Write
	cmd.SetInitiatorTaskTag(1)
	cmd.SetExpectedDataTransferLength(4096)
	cmd.SetCmdSN(2)

	var cdb [16]byte
	cdb[0] = ScsiWrite10
	binary.BigEndian.PutUint32(cdb[2:6], 5)
	binary.BigEndian.PutUint16(cdb[7:9], 1)
	cmd.SetCDB(cdb)

	// Include immediate data
	cmd.DataSegment = bytes.Repeat([]byte{0xAA}, 4096)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	// Expect SCSI Response (good)
	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSIResp {
		t.Fatalf("expected SCSI Response, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.SCSIStatus() != SCSIStatusGood {
		t.Fatalf("status: %d", resp.SCSIStatus())
	}
}
||||
|
|
||||
|
// testNOPPing sends an immediate NOP-Out carrying "ping" and expects a
// NOP-In echoing the same ITT and payload.
func testNOPPing(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send NOP-Out
	nop := &PDU{}
	nop.SetOpcode(OpNOPOut)
	nop.SetOpSpecific1(FlagF)
	nop.SetInitiatorTaskTag(0x9999)
	nop.SetImmediate(true)
	nop.DataSegment = []byte("ping")

	if err := WritePDU(env.clientConn, nop); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpNOPIn {
		t.Fatalf("expected NOP-In, got %s", OpcodeName(resp.Opcode()))
	}
	if resp.InitiatorTaskTag() != 0x9999 {
		t.Fatal("ITT mismatch")
	}
	if string(resp.DataSegment) != "ping" {
		t.Fatalf("echo data: %q", resp.DataSegment)
	}
}
||||
|
|
||||
|
// testLogout verifies a Logout request is answered with a Logout response
// and that the target then closes the connection.
func testLogout(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Send Logout
	logout := &PDU{}
	logout.SetOpcode(OpLogoutReq)
	logout.SetOpSpecific1(FlagF)
	logout.SetInitiatorTaskTag(0xAAAA)
	logout.SetCmdSN(2)

	if err := WritePDU(env.clientConn, logout); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpLogoutResp {
		t.Fatalf("expected Logout Response, got %s", OpcodeName(resp.Opcode()))
	}

	// After logout, the connection should be closed by the target.
	// Verify by trying to read — should get EOF.
	_, err = ReadPDU(env.clientConn)
	if err == nil {
		t.Fatal("expected EOF after logout")
	}
}
||||
|
|
||||
|
func testDiscoverySession(t *testing.T) { |
||||
|
env := setupSession(t) |
||||
|
|
||||
|
// Login as discovery session
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("SessionType", "Discovery") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
req.SetCmdSN(1) |
||||
|
WritePDU(env.clientConn, req) |
||||
|
ReadPDU(env.clientConn) // login resp
|
||||
|
|
||||
|
// Send SendTargets=All
|
||||
|
textParams := NewParams() |
||||
|
textParams.Set("SendTargets", "All") |
||||
|
textReq := makeTextReq(textParams) |
||||
|
textReq.SetCmdSN(2) |
||||
|
|
||||
|
if err := WritePDU(env.clientConn, textReq); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
resp, err := ReadPDU(env.clientConn) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if resp.Opcode() != OpTextResp { |
||||
|
t.Fatalf("expected Text Response, got %s", OpcodeName(resp.Opcode())) |
||||
|
} |
||||
|
body := string(resp.DataSegment) |
||||
|
if len(body) == 0 { |
||||
|
t.Fatal("empty discovery response") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testTaskMgmt sends an ABORT TASK task-management request and expects the
// simplified target to answer with a Task Management response.
func testTaskMgmt(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	tm := &PDU{}
	tm.SetOpcode(OpSCSITaskMgmt)
	tm.SetOpSpecific1(FlagF | 0x01) // ABORT TASK
	tm.SetInitiatorTaskTag(0xBBBB)
	tm.SetImmediate(true)

	if err := WritePDU(env.clientConn, tm); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpSCSITaskResp {
		t.Fatalf("expected Task Mgmt Response, got %s", OpcodeName(resp.Opcode()))
	}
}
||||
|
|
||||
|
// testRejectSCSIBeforeLogin verifies that a SCSI command sent before login
// completes is answered with a Reject PDU instead of being executed.
func testRejectSCSIBeforeLogin(t *testing.T) {
	env := setupSession(t)

	// Send SCSI command without login
	cmd := &PDU{}
	cmd.SetOpcode(OpSCSICmd)
	cmd.SetOpSpecific1(FlagF | FlagR)
	cmd.SetInitiatorTaskTag(1)

	if err := WritePDU(env.clientConn, cmd); err != nil {
		t.Fatal(err)
	}

	resp, err := ReadPDU(env.clientConn)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Opcode() != OpReject {
		t.Fatalf("expected Reject, got %s", OpcodeName(resp.Opcode()))
	}
}
||||
|
|
||||
|
// testConnectionClose verifies HandleConnection returns nil (clean
// shutdown, not an error) when the initiator drops the connection.
func testConnectionClose(t *testing.T) {
	env := setupSession(t)
	doLogin(t, env.clientConn)

	// Close client side
	env.clientConn.Close()

	select {
	case err := <-env.done:
		if err != nil {
			t.Fatalf("unexpected error on clean close: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("session did not detect close")
	}
}
||||
@ -0,0 +1,216 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"log" |
||||
|
"net" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
) |
||||
|
|
||||
|
// Sentinel errors for the target server.
// NOTE(review): neither sentinel is referenced in this file's visible code
// — confirm they are used by callers or remove them.
var (
	ErrTargetClosed   = errors.New("iscsi: target server closed")
	ErrVolumeNotFound = errors.New("iscsi: volume not found")
)
||||
|
|
||||
|
// TargetServer manages the iSCSI target: TCP listener, volume registry,
// and active sessions.
type TargetServer struct {
	mu       sync.RWMutex // guards listener and volumes
	listener net.Listener
	config   TargetConfig
	volumes  map[string]BlockDevice // target IQN -> device
	addr     string                 // configured listen address (before bind)

	// Active session tracking for graceful shutdown
	activeMu sync.Mutex
	active   map[uint64]*Session
	nextID   atomic.Uint64 // monotonically increasing session IDs

	sessions sync.WaitGroup // one count per live connection goroutine
	closed   chan struct{}  // closed by Close() to signal shutdown
	logger   *log.Logger
}
||||
|
|
||||
|
// NewTargetServer creates a target server bound to the given address.
// A nil logger falls back to log.Default(). The listener is not created
// until ListenAndServe or Serve is called.
func NewTargetServer(addr string, config TargetConfig, logger *log.Logger) *TargetServer {
	if logger == nil {
		logger = log.Default()
	}
	return &TargetServer{
		config:  config,
		volumes: make(map[string]BlockDevice),
		active:  make(map[uint64]*Session),
		addr:    addr,
		closed:  make(chan struct{}),
		logger:  logger,
	}
}
||||
|
|
||||
|
// AddVolume registers a block device under the given target IQN.
|
||||
|
func (ts *TargetServer) AddVolume(iqn string, dev BlockDevice) { |
||||
|
ts.mu.Lock() |
||||
|
defer ts.mu.Unlock() |
||||
|
ts.volumes[iqn] = dev |
||||
|
ts.logger.Printf("volume added: %s", iqn) |
||||
|
} |
||||
|
|
||||
|
// RemoveVolume unregisters a target IQN. Removing a non-existent IQN is a
// no-op (map delete semantics).
func (ts *TargetServer) RemoveVolume(iqn string) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	delete(ts.volumes, iqn)
	ts.logger.Printf("volume removed: %s", iqn)
}
||||
|
|
||||
|
// HasTarget implements TargetResolver.
|
||||
|
func (ts *TargetServer) HasTarget(name string) bool { |
||||
|
ts.mu.RLock() |
||||
|
defer ts.mu.RUnlock() |
||||
|
_, ok := ts.volumes[name] |
||||
|
return ok |
||||
|
} |
||||
|
|
||||
|
// ListTargets implements TargetLister.
|
||||
|
func (ts *TargetServer) ListTargets() []DiscoveryTarget { |
||||
|
ts.mu.RLock() |
||||
|
defer ts.mu.RUnlock() |
||||
|
targets := make([]DiscoveryTarget, 0, len(ts.volumes)) |
||||
|
for iqn := range ts.volumes { |
||||
|
targets = append(targets, DiscoveryTarget{ |
||||
|
Name: iqn, |
||||
|
Address: ts.ListenAddr(), |
||||
|
}) |
||||
|
} |
||||
|
return targets |
||||
|
} |
||||
|
|
||||
|
// ListenAndServe starts listening for iSCSI connections on the configured
// address. Blocks until Close() is called or an error occurs.
func (ts *TargetServer) ListenAndServe() error {
	ln, err := net.Listen("tcp", ts.addr)
	if err != nil {
		return fmt.Errorf("iscsi target: listen: %w", err)
	}
	// Publish the listener so ListenAddr/Close can see it.
	ts.mu.Lock()
	ts.listener = ln
	ts.mu.Unlock()

	ts.logger.Printf("iSCSI target listening on %s", ln.Addr())
	return ts.acceptLoop(ln)
}
||||
|
|
||||
|
// Serve accepts connections on an existing listener (useful for tests and
// for callers that want to bind the socket themselves).
func (ts *TargetServer) Serve(ln net.Listener) error {
	ts.mu.Lock()
	ts.listener = ln
	ts.mu.Unlock()
	return ts.acceptLoop(ln)
}
||||
|
|
||||
|
// acceptLoop accepts connections until the listener fails. An accept error
// that follows Close() is reported as a clean shutdown (nil).
func (ts *TargetServer) acceptLoop(ln net.Listener) error {
	for {
		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-ts.closed:
				// Close() shut the listener down deliberately.
				return nil
			default:
				return fmt.Errorf("iscsi target: accept: %w", err)
			}
		}

		// Add before launching so Close()'s Wait covers this connection.
		ts.sessions.Add(1)
		go ts.handleConn(conn)
	}
}
||||
|
|
||||
|
// handleConn runs one connection's session to completion, registering it
// in the active-session table so Close() can interrupt it.
func (ts *TargetServer) handleConn(conn net.Conn) {
	defer ts.sessions.Done()
	defer conn.Close()

	ts.logger.Printf("new connection from %s", conn.RemoteAddr())

	// The server itself acts as TargetResolver and DeviceLookup.
	sess := NewSession(conn, ts.config, ts, ts, ts.logger)

	id := ts.nextID.Add(1)
	ts.activeMu.Lock()
	ts.active[id] = sess
	ts.activeMu.Unlock()

	defer func() {
		ts.activeMu.Lock()
		delete(ts.active, id)
		ts.activeMu.Unlock()
	}()

	if err := sess.HandleConnection(); err != nil {
		ts.logger.Printf("session error (%s): %v", conn.RemoteAddr(), err)
	}

	ts.logger.Printf("connection closed: %s", conn.RemoteAddr())
}
||||
|
|
||||
|
// LookupDevice implements DeviceLookup. Returns the BlockDevice for the
// given IQN, or a nullDevice stub if not found (callers always get a
// usable, if empty, device).
func (ts *TargetServer) LookupDevice(iqn string) BlockDevice {
	ts.mu.RLock()
	defer ts.mu.RUnlock()

	if dev, ok := ts.volumes[iqn]; ok {
		return dev
	}
	return &nullDevice{}
}
||||
|
|
||||
|
// Close gracefully shuts down the target server: stop accepting, close all
// active sessions (unblocking their ReadPDU calls), and wait for every
// connection goroutine to finish. Idempotent.
func (ts *TargetServer) Close() error {
	// Already closed? (closed is only ever closed here, never sent on.)
	select {
	case <-ts.closed:
		return nil
	default:
	}
	close(ts.closed)

	ts.mu.RLock()
	ln := ts.listener
	ts.mu.RUnlock()

	if ln != nil {
		ln.Close()
	}

	// Close all active sessions to unblock ReadPDU
	ts.activeMu.Lock()
	for _, sess := range ts.active {
		sess.Close()
	}
	ts.activeMu.Unlock()

	ts.sessions.Wait()
	return nil
}
||||
|
|
||||
|
// ListenAddr returns the actual listen address (useful when port=0); falls
// back to the configured address if the listener has not been created yet.
func (ts *TargetServer) ListenAddr() string {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	if ts.listener != nil {
		return ts.listener.Addr().String()
	}
	return ts.addr
}
||||
|
|
||||
|
// nullDevice is a stub BlockDevice used when no volumes are registered.
// Reads return zeros, writes/trims/syncs succeed silently, size is zero,
// and IsHealthy reports false so callers can detect the stub.
type nullDevice struct{}

// ReadAt returns length zero bytes regardless of lba.
func (d *nullDevice) ReadAt(lba uint64, length uint32) ([]byte, error) {
	return make([]byte, length), nil
}

func (d *nullDevice) WriteAt(lba uint64, data []byte) error { return nil }
func (d *nullDevice) Trim(lba uint64, length uint32) error  { return nil }
func (d *nullDevice) SyncCache() error                      { return nil }
func (d *nullDevice) BlockSize() uint32                     { return 4096 }
func (d *nullDevice) VolumeSize() uint64                    { return 0 }
func (d *nullDevice) IsHealthy() bool                       { return false }
@ -0,0 +1,259 @@ |
|||||
|
package iscsi |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"io" |
||||
|
"log" |
||||
|
"net" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestTarget(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{"listen_and_connect", testListenAndConnect}, |
||||
|
{"discovery_via_target", testDiscoveryViaTarget}, |
||||
|
{"login_read_write", testTargetLoginReadWrite}, |
||||
|
{"graceful_shutdown", testGracefulShutdown}, |
||||
|
{"add_remove_volume", testAddRemoveVolume}, |
||||
|
{"multiple_connections", testMultipleConnections}, |
||||
|
{"connect_no_volumes", testConnectNoVolumes}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func setupTarget(t *testing.T) (*TargetServer, string) { |
||||
|
t.Helper() |
||||
|
config := DefaultTargetConfig() |
||||
|
config.TargetName = testTargetName |
||||
|
logger := log.New(io.Discard, "", 0) |
||||
|
ts := NewTargetServer("127.0.0.1:0", config, logger) |
||||
|
|
||||
|
dev := newMockDevice(256 * 4096) // 256 blocks = 1MB
|
||||
|
ts.AddVolume(testTargetName, dev) |
||||
|
|
||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
go ts.Serve(ln) |
||||
|
|
||||
|
addr := ln.Addr().String() |
||||
|
t.Cleanup(func() { |
||||
|
ts.Close() |
||||
|
}) |
||||
|
|
||||
|
return ts, addr |
||||
|
} |
||||
|
|
||||
|
func dialTarget(t *testing.T, addr string) net.Conn { |
||||
|
t.Helper() |
||||
|
conn, err := net.DialTimeout("tcp", addr, 2*time.Second) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
t.Cleanup(func() { conn.Close() }) |
||||
|
return conn |
||||
|
} |
||||
|
|
||||
|
func loginToTarget(t *testing.T, conn net.Conn) { |
||||
|
t.Helper() |
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("TargetName", testTargetName) |
||||
|
params.Set("SessionType", "Normal") |
||||
|
|
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
req.SetCmdSN(1) |
||||
|
|
||||
|
if err := WritePDU(conn, req); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
resp, err := ReadPDU(conn) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatalf("login failed: %d/%d", resp.LoginStatusClass(), resp.LoginStatusDetail()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testListenAndConnect(t *testing.T) { |
||||
|
_, addr := setupTarget(t) |
||||
|
|
||||
|
conn := dialTarget(t, addr) |
||||
|
loginToTarget(t, conn) |
||||
|
} |
||||
|
|
||||
|
func testDiscoveryViaTarget(t *testing.T) { |
||||
|
_, addr := setupTarget(t) |
||||
|
|
||||
|
conn := dialTarget(t, addr) |
||||
|
|
||||
|
// Login as discovery
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("SessionType", "Discovery") |
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
req.SetCmdSN(1) |
||||
|
WritePDU(conn, req) |
||||
|
ReadPDU(conn) |
||||
|
|
||||
|
// SendTargets=All
|
||||
|
textParams := NewParams() |
||||
|
textParams.Set("SendTargets", "All") |
||||
|
textReq := makeTextReq(textParams) |
||||
|
textReq.SetCmdSN(2) |
||||
|
WritePDU(conn, textReq) |
||||
|
|
||||
|
resp, err := ReadPDU(conn) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
body := string(resp.DataSegment) |
||||
|
if !strings.Contains(body, testTargetName) { |
||||
|
t.Fatalf("discovery response missing target: %q", body) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testTargetLoginReadWrite(t *testing.T) { |
||||
|
_, addr := setupTarget(t) |
||||
|
conn := dialTarget(t, addr) |
||||
|
loginToTarget(t, conn) |
||||
|
|
||||
|
// Write 1 block at LBA 0
|
||||
|
cmd := &PDU{} |
||||
|
cmd.SetOpcode(OpSCSICmd) |
||||
|
cmd.SetOpSpecific1(FlagF | FlagW) |
||||
|
cmd.SetInitiatorTaskTag(1) |
||||
|
cmd.SetExpectedDataTransferLength(4096) |
||||
|
cmd.SetCmdSN(2) |
||||
|
var cdb [16]byte |
||||
|
cdb[0] = ScsiWrite10 |
||||
|
binary.BigEndian.PutUint32(cdb[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(cdb[7:9], 1) |
||||
|
cmd.SetCDB(cdb) |
||||
|
data := make([]byte, 4096) |
||||
|
for i := range data { |
||||
|
data[i] = 0xAB |
||||
|
} |
||||
|
cmd.DataSegment = data |
||||
|
|
||||
|
WritePDU(conn, cmd) |
||||
|
resp, _ := ReadPDU(conn) |
||||
|
if resp.SCSIStatus() != SCSIStatusGood { |
||||
|
t.Fatalf("write failed: %d", resp.SCSIStatus()) |
||||
|
} |
||||
|
|
||||
|
// Read it back
|
||||
|
cmd2 := &PDU{} |
||||
|
cmd2.SetOpcode(OpSCSICmd) |
||||
|
cmd2.SetOpSpecific1(FlagF | FlagR) |
||||
|
cmd2.SetInitiatorTaskTag(2) |
||||
|
cmd2.SetExpectedDataTransferLength(4096) |
||||
|
cmd2.SetCmdSN(3) |
||||
|
var cdb2 [16]byte |
||||
|
cdb2[0] = ScsiRead10 |
||||
|
binary.BigEndian.PutUint32(cdb2[2:6], 0) |
||||
|
binary.BigEndian.PutUint16(cdb2[7:9], 1) |
||||
|
cmd2.SetCDB(cdb2) |
||||
|
|
||||
|
WritePDU(conn, cmd2) |
||||
|
resp2, _ := ReadPDU(conn) |
||||
|
if resp2.Opcode() != OpSCSIDataIn { |
||||
|
t.Fatalf("expected Data-In, got %s", OpcodeName(resp2.Opcode())) |
||||
|
} |
||||
|
if resp2.DataSegment[0] != 0xAB { |
||||
|
t.Fatal("data mismatch") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testGracefulShutdown(t *testing.T) { |
||||
|
ts, addr := setupTarget(t) |
||||
|
|
||||
|
conn := dialTarget(t, addr) |
||||
|
loginToTarget(t, conn) |
||||
|
|
||||
|
// Close the target — should shut down cleanly
|
||||
|
ts.Close() |
||||
|
|
||||
|
// Connection should be dropped
|
||||
|
_, err := ReadPDU(conn) |
||||
|
if err == nil { |
||||
|
t.Fatal("expected error after shutdown") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testAddRemoveVolume(t *testing.T) { |
||||
|
config := DefaultTargetConfig() |
||||
|
logger := log.New(io.Discard, "", 0) |
||||
|
ts := NewTargetServer("127.0.0.1:0", config, logger) |
||||
|
|
||||
|
ts.AddVolume("iqn.2024.com.test:v1", newMockDevice(1024*4096)) |
||||
|
ts.AddVolume("iqn.2024.com.test:v2", newMockDevice(1024*4096)) |
||||
|
|
||||
|
if !ts.HasTarget("iqn.2024.com.test:v1") { |
||||
|
t.Fatal("v1 should exist") |
||||
|
} |
||||
|
|
||||
|
ts.RemoveVolume("iqn.2024.com.test:v1") |
||||
|
if ts.HasTarget("iqn.2024.com.test:v1") { |
||||
|
t.Fatal("v1 should be removed") |
||||
|
} |
||||
|
if !ts.HasTarget("iqn.2024.com.test:v2") { |
||||
|
t.Fatal("v2 should still exist") |
||||
|
} |
||||
|
|
||||
|
targets := ts.ListTargets() |
||||
|
if len(targets) != 1 { |
||||
|
t.Fatalf("expected 1 target, got %d", len(targets)) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testMultipleConnections(t *testing.T) { |
||||
|
_, addr := setupTarget(t) |
||||
|
|
||||
|
// Connect multiple clients
|
||||
|
for i := 0; i < 5; i++ { |
||||
|
conn := dialTarget(t, addr) |
||||
|
loginToTarget(t, conn) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testConnectNoVolumes(t *testing.T) { |
||||
|
config := DefaultTargetConfig() |
||||
|
logger := log.New(io.Discard, "", 0) |
||||
|
ts := NewTargetServer("127.0.0.1:0", config, logger) |
||||
|
// No volumes added
|
||||
|
|
||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
go ts.Serve(ln) |
||||
|
t.Cleanup(func() { ts.Close() }) |
||||
|
|
||||
|
conn := dialTarget(t, ln.Addr().String()) |
||||
|
|
||||
|
// Login should work (discovery is fine without volumes)
|
||||
|
params := NewParams() |
||||
|
params.Set("InitiatorName", testInitiatorName) |
||||
|
params.Set("SessionType", "Discovery") |
||||
|
req := makeLoginReq(StageSecurityNeg, StageFullFeature, true, params) |
||||
|
req.SetCmdSN(1) |
||||
|
WritePDU(conn, req) |
||||
|
resp, _ := ReadPDU(conn) |
||||
|
if resp.LoginStatusClass() != LoginStatusSuccess { |
||||
|
t.Fatal("discovery login should succeed without volumes") |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,38 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
ErrLBAOutOfBounds = errors.New("blockvol: LBA out of bounds") |
||||
|
ErrWritePastEnd = errors.New("blockvol: write extends past volume end") |
||||
|
ErrAlignment = errors.New("blockvol: data length not aligned to block size") |
||||
|
) |
||||
|
|
||||
|
// ValidateLBA checks that lba is within the volume's logical address space.
|
||||
|
func ValidateLBA(lba uint64, volumeSize uint64, blockSize uint32) error { |
||||
|
maxLBA := volumeSize / uint64(blockSize) |
||||
|
if lba >= maxLBA { |
||||
|
return fmt.Errorf("%w: lba=%d, max=%d", ErrLBAOutOfBounds, lba, maxLBA-1) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// ValidateWrite checks that a write at lba with dataLen bytes fits within
|
||||
|
// the volume and that dataLen is aligned to blockSize.
|
||||
|
func ValidateWrite(lba uint64, dataLen uint32, volumeSize uint64, blockSize uint32) error { |
||||
|
if err := ValidateLBA(lba, volumeSize, blockSize); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
if dataLen%blockSize != 0 { |
||||
|
return fmt.Errorf("%w: dataLen=%d, blockSize=%d", ErrAlignment, dataLen, blockSize) |
||||
|
} |
||||
|
blocksNeeded := uint64(dataLen / blockSize) |
||||
|
maxLBA := volumeSize / uint64(blockSize) |
||||
|
if lba+blocksNeeded > maxLBA { |
||||
|
return fmt.Errorf("%w: lba=%d, blocks=%d, max=%d", ErrWritePastEnd, lba, blocksNeeded, maxLBA) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
@ -0,0 +1,95 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestLBAValidation(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "lba_within_bounds", run: testLBAWithinBounds}, |
||||
|
{name: "lba_last_block", run: testLBALastBlock}, |
||||
|
{name: "lba_out_of_bounds", run: testLBAOutOfBounds}, |
||||
|
{name: "lba_write_spans_end", run: testLBAWriteSpansEnd}, |
||||
|
{name: "lba_alignment", run: testLBAAlignment}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Shared fixture geometry for the LBA validation subtests.
const (
	testVolSize   = 100 * 1024 * 1024 * 1024 // 100GB
	testBlockSize = 4096
)
||||
|
|
||||
|
func testLBAWithinBounds(t *testing.T) { |
||||
|
err := ValidateLBA(0, testVolSize, testBlockSize) |
||||
|
if err != nil { |
||||
|
t.Errorf("LBA=0 should be valid: %v", err) |
||||
|
} |
||||
|
|
||||
|
err = ValidateLBA(1000, testVolSize, testBlockSize) |
||||
|
if err != nil { |
||||
|
t.Errorf("LBA=1000 should be valid: %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testLBALastBlock(t *testing.T) { |
||||
|
maxLBA := uint64(testVolSize/testBlockSize) - 1 |
||||
|
err := ValidateLBA(maxLBA, testVolSize, testBlockSize) |
||||
|
if err != nil { |
||||
|
t.Errorf("last LBA=%d should be valid: %v", maxLBA, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testLBAOutOfBounds(t *testing.T) { |
||||
|
maxLBA := uint64(testVolSize / testBlockSize) |
||||
|
err := ValidateLBA(maxLBA, testVolSize, testBlockSize) |
||||
|
if !errors.Is(err, ErrLBAOutOfBounds) { |
||||
|
t.Errorf("LBA=%d (one past end): expected ErrLBAOutOfBounds, got %v", maxLBA, err) |
||||
|
} |
||||
|
|
||||
|
err = ValidateLBA(maxLBA+1000, testVolSize, testBlockSize) |
||||
|
if !errors.Is(err, ErrLBAOutOfBounds) { |
||||
|
t.Errorf("LBA far past end: expected ErrLBAOutOfBounds, got %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testLBAWriteSpansEnd(t *testing.T) { |
||||
|
maxLBA := uint64(testVolSize / testBlockSize) |
||||
|
// Write 2 blocks starting at last LBA -- spans past end.
|
||||
|
lastLBA := maxLBA - 1 |
||||
|
err := ValidateWrite(lastLBA, 2*testBlockSize, testVolSize, testBlockSize) |
||||
|
if !errors.Is(err, ErrWritePastEnd) { |
||||
|
t.Errorf("write spanning end: expected ErrWritePastEnd, got %v", err) |
||||
|
} |
||||
|
|
||||
|
// Write 1 block at last LBA -- should succeed.
|
||||
|
err = ValidateWrite(lastLBA, testBlockSize, testVolSize, testBlockSize) |
||||
|
if err != nil { |
||||
|
t.Errorf("write at last LBA should succeed: %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testLBAAlignment(t *testing.T) { |
||||
|
err := ValidateWrite(0, 4096, testVolSize, testBlockSize) |
||||
|
if err != nil { |
||||
|
t.Errorf("aligned write should succeed: %v", err) |
||||
|
} |
||||
|
|
||||
|
err = ValidateWrite(0, 4000, testVolSize, testBlockSize) |
||||
|
if !errors.Is(err, ErrAlignment) { |
||||
|
t.Errorf("unaligned write: expected ErrAlignment, got %v", err) |
||||
|
} |
||||
|
|
||||
|
err = ValidateWrite(0, 1, testVolSize, testBlockSize) |
||||
|
if !errors.Is(err, ErrAlignment) { |
||||
|
t.Errorf("1-byte write: expected ErrAlignment, got %v", err) |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,157 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"os" |
||||
|
) |
||||
|
|
||||
|
// RecoveryResult contains the outcome of WAL recovery, as returned by
// RecoverWAL.
type RecoveryResult struct {
	EntriesReplayed int    // number of WRITE/TRIM entries replayed into the dirty map
	HighestLSN      uint64 // highest LSN seen among replayed entries during recovery
	TornEntries     int    // entries discarded due to CRC failure or truncation (torn writes)
}
||||
|
|
||||
|
// RecoverWAL scans the WAL region from tail to head, replaying valid entries
|
||||
|
// into the dirty map. Entries with LSN <= checkpointLSN are skipped (already
|
||||
|
// in extent). Scanning stops at the first CRC failure (torn write).
|
||||
|
//
|
||||
|
// The WAL is a circular buffer. If head >= tail, scan [tail, head).
|
||||
|
// If head < tail (wrapped), scan [tail, walSize) then [0, head).
|
||||
|
func RecoverWAL(fd *os.File, sb *Superblock, dirtyMap *DirtyMap) (RecoveryResult, error) { |
||||
|
result := RecoveryResult{} |
||||
|
|
||||
|
logicalHead := sb.WALHead |
||||
|
logicalTail := sb.WALTail |
||||
|
walOffset := sb.WALOffset |
||||
|
walSize := sb.WALSize |
||||
|
checkpointLSN := sb.WALCheckpointLSN |
||||
|
|
||||
|
if logicalHead == logicalTail { |
||||
|
// WAL is empty (or fully flushed).
|
||||
|
return result, nil |
||||
|
} |
||||
|
|
||||
|
// Convert logical positions to physical.
|
||||
|
physHead := logicalHead % walSize |
||||
|
physTail := logicalTail % walSize |
||||
|
|
||||
|
// Build the list of byte ranges to scan.
|
||||
|
type scanRange struct { |
||||
|
start, end uint64 // physical positions within WAL
|
||||
|
} |
||||
|
|
||||
|
var ranges []scanRange |
||||
|
if physHead > physTail { |
||||
|
// No wrap: scan [tail, head).
|
||||
|
ranges = append(ranges, scanRange{physTail, physHead}) |
||||
|
} else if physHead == physTail { |
||||
|
// Head and tail at same physical position but different logical positions
|
||||
|
// means the WAL is completely full. Scan the entire region.
|
||||
|
ranges = append(ranges, scanRange{physTail, walSize}) |
||||
|
if physHead > 0 { |
||||
|
ranges = append(ranges, scanRange{0, physHead}) |
||||
|
} |
||||
|
} else { |
||||
|
// Wrapped: scan [tail, walSize) then [0, head).
|
||||
|
ranges = append(ranges, scanRange{physTail, walSize}) |
||||
|
if physHead > 0 { |
||||
|
ranges = append(ranges, scanRange{0, physHead}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for _, r := range ranges { |
||||
|
pos := r.start |
||||
|
for pos < r.end { |
||||
|
remaining := r.end - pos |
||||
|
|
||||
|
// Need at least a header to proceed.
|
||||
|
if remaining < uint64(walEntryHeaderSize) { |
||||
|
break |
||||
|
} |
||||
|
|
||||
|
// Read header.
|
||||
|
headerBuf := make([]byte, walEntryHeaderSize) |
||||
|
absOff := int64(walOffset + pos) |
||||
|
if _, err := fd.ReadAt(headerBuf, absOff); err != nil { |
||||
|
return result, fmt.Errorf("recovery: read header at WAL+%d: %w", pos, err) |
||||
|
} |
||||
|
|
||||
|
// Parse entry type and length field.
|
||||
|
entryType := headerBuf[16] |
||||
|
lengthField := parseLength(headerBuf) |
||||
|
|
||||
|
// For padding entries, skip forward.
|
||||
|
if entryType == EntryTypePadding { |
||||
|
entrySize := uint64(walEntryHeaderSize) + uint64(lengthField) |
||||
|
pos += entrySize |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
// Calculate on-disk entry size. WRITE and PADDING carry data payload;
|
||||
|
// TRIM and BARRIER do not (Length is metadata, not data size).
|
||||
|
var payloadLen uint64 |
||||
|
if entryType == EntryTypeWrite { |
||||
|
payloadLen = uint64(lengthField) |
||||
|
} |
||||
|
entrySize := uint64(walEntryHeaderSize) + payloadLen |
||||
|
if entrySize > remaining { |
||||
|
// Torn write: entry extends past available data.
|
||||
|
result.TornEntries++ |
||||
|
break |
||||
|
} |
||||
|
|
||||
|
// Read full entry.
|
||||
|
fullBuf := make([]byte, entrySize) |
||||
|
if _, err := fd.ReadAt(fullBuf, absOff); err != nil { |
||||
|
return result, fmt.Errorf("recovery: read entry at WAL+%d: %w", pos, err) |
||||
|
} |
||||
|
|
||||
|
// Decode and validate CRC.
|
||||
|
entry, err := DecodeWALEntry(fullBuf) |
||||
|
if err != nil { |
||||
|
// CRC failure or corrupt entry — stop here (torn write).
|
||||
|
result.TornEntries++ |
||||
|
break |
||||
|
} |
||||
|
|
||||
|
// Skip entries already flushed to extent.
|
||||
|
if entry.LSN <= checkpointLSN { |
||||
|
pos += entrySize |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
// Replay entry.
|
||||
|
switch entry.Type { |
||||
|
case EntryTypeWrite: |
||||
|
blocks := entry.Length / sb.BlockSize |
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) |
||||
|
} |
||||
|
result.EntriesReplayed++ |
||||
|
|
||||
|
case EntryTypeTrim: |
||||
|
// TRIM carries Length (bytes) covering multiple blocks.
|
||||
|
blocks := entry.Length / sb.BlockSize |
||||
|
if blocks == 0 { |
||||
|
blocks = 1 // legacy single-block trim
|
||||
|
} |
||||
|
for i := uint32(0); i < blocks; i++ { |
||||
|
dirtyMap.Put(entry.LBA+uint64(i), pos, entry.LSN, sb.BlockSize) |
||||
|
} |
||||
|
result.EntriesReplayed++ |
||||
|
|
||||
|
case EntryTypeBarrier: |
||||
|
// Barriers don't modify data, just skip.
|
||||
|
} |
||||
|
|
||||
|
if entry.LSN > result.HighestLSN { |
||||
|
result.HighestLSN = entry.LSN |
||||
|
} |
||||
|
|
||||
|
pos += entrySize |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return result, nil |
||||
|
} |
||||
@ -0,0 +1,416 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"path/filepath" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestRecovery(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "recover_empty_wal", run: testRecoverEmptyWAL}, |
||||
|
{name: "recover_one_entry", run: testRecoverOneEntry}, |
||||
|
{name: "recover_many_entries", run: testRecoverManyEntries}, |
||||
|
{name: "recover_torn_write", run: testRecoverTornWrite}, |
||||
|
{name: "recover_after_checkpoint", run: testRecoverAfterCheckpoint}, |
||||
|
{name: "recover_idempotent", run: testRecoverIdempotent}, |
||||
|
{name: "recover_wal_full", run: testRecoverWALFull}, |
||||
|
{name: "recover_barrier_only", run: testRecoverBarrierOnly}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// simulateCrash closes the volume without syncing and returns the path.
|
||||
|
func simulateCrash(v *BlockVol) string { |
||||
|
path := v.Path() |
||||
|
v.groupCommit.Stop() |
||||
|
v.fd.Close() |
||||
|
return path |
||||
|
} |
||||
|
|
||||
|
// testRecoverEmptyWAL: reopening a volume whose WAL never received an entry
// must succeed and serve zeros.
func testRecoverEmptyWAL(t *testing.T) {
	v := createTestVol(t)
	// Sync to make superblock durable.
	v.fd.Sync()
	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Empty volume: read should return zeros.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("expected zeros from empty volume")
	}
}
||||
|
|
||||
|
// testRecoverOneEntry: a single synced WAL write must survive a crash and be
// served from the recovered dirty map.
func testRecoverOneEntry(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('A')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	// Sync WAL to make the entry durable.
	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Update superblock with WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// The pre-crash block contents must be readable again.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA after recovery: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("data not recovered")
	}
}
||||
|
|
||||
|
// testRecoverManyEntries: hundreds of WAL entries all replay after a crash.
func testRecoverManyEntries(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "many.blockvol")
	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 4 * 1024 * 1024, // 4MB
		BlockSize:  4096,
		WALSize:    2 * 1024 * 1024, // 2MB WAL -- enough for ~500 entries
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	const numWrites = 400
	for i := uint64(0); i < numWrites; i++ {
		// Cycle payloads through 'A'..'Z' so each block is distinguishable.
		if err := v.WriteLBA(i, makeBlock(byte(i%26+'A'))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL head/tail in the superblock, then crash.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path = simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	for i := uint64(0); i < numWrites; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d) after recovery: %v", i, err)
		}
		expected := makeBlock(byte(i%26 + 'A'))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after recovery", i)
		}
	}
}
||||
|
|
||||
|
// testRecoverTornWrite: a CRC-corrupted second entry is discarded during
// recovery while the intact first entry is still replayed.
func testRecoverTornWrite(t *testing.T) {
	v := createTestVol(t)

	// Write 2 entries.
	if err := v.WriteLBA(0, makeBlock('A')); err != nil {
		t.Fatalf("WriteLBA(0): %v", err)
	}
	if err := v.WriteLBA(1, makeBlock('B')); err != nil {
		t.Fatalf("WriteLBA(1): %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Save superblock with correct WAL state.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	// Corrupt the last 2 bytes of the second entry to simulate torn write.
	entrySize := uint64(walEntryHeaderSize + 4096)
	secondEntryEnd := v.super.WALOffset + entrySize*2
	corruptOff := int64(secondEntryEnd - 2)
	v.fd.WriteAt([]byte{0xFF, 0xFF}, corruptOff)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// First entry should be recovered.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(0) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, makeBlock('A')) {
		t.Error("block 0 should be recovered")
	}

	// Second entry was torn -- should NOT be in dirty map.
	// Read returns zeros (from extent).
	got, err = v2.ReadLBA(1, 4096)
	if err != nil {
		t.Fatalf("ReadLBA(1) after torn recovery: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("block 1 (torn) should return zeros")
	}
}
||||
|
|
||||
|
// testRecoverAfterCheckpoint: blocks flushed to the extent region before the
// crash stay readable, and blocks written after the checkpoint are replayed
// from the WAL.
func testRecoverAfterCheckpoint(t *testing.T) {
	v := createTestVol(t)

	// Write blocks 0-9.
	for i := uint64(0); i < 10; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Flush first 5 blocks (flusher moves them to extent, advances checkpoint).
	f := NewFlusher(FlusherConfig{
		FD:       v.fd,
		Super:    &v.super,
		WAL:      v.wal,
		DirtyMap: v.dirtyMap,
		Interval: 1 * time.Hour, // long interval: we drive the flusher manually
	})

	// Flush all (flusher takes all dirty entries).
	// To simulate partial flush, we'll manually set checkpoint to midpoint.
	if err := f.FlushOnce(); err != nil {
		t.Fatalf("FlushOnce: %v", err)
	}

	// Now write 5 more blocks (these will need replay after crash).
	for i := uint64(10); i < 15; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Update superblock.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// Blocks 0-9 should be in extent (flushed). Blocks 10-14 replayed from WAL.
	for i := uint64(0); i < 15; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after checkpoint recovery", i)
		}
	}
}
||||
|
|
||||
|
// testRecoverIdempotent: recovering, crashing again without new writes, and
// recovering a second time must yield the same data each time.
func testRecoverIdempotent(t *testing.T) {
	v := createTestVol(t)
	data := makeBlock('X')
	if err := v.WriteLBA(0, data); err != nil {
		t.Fatalf("WriteLBA: %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL state so recovery can see the entry.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	// First recovery.
	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol 1: %v", err)
	}

	// Verify data.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 1: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("first recovery: data mismatch")
	}

	// Close and recover again (should be idempotent).
	path2 := simulateCrash(v2)

	v3, err := OpenBlockVol(path2)
	if err != nil {
		t.Fatalf("OpenBlockVol 2: %v", err)
	}
	defer v3.Close()

	got, err = v3.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA 2: %v", err)
	}
	if !bytes.Equal(got, data) {
		t.Error("second recovery: data mismatch")
	}
}
||||
|
|
||||
|
// testRecoverWALFull: a WAL sized for exactly five entries, filled to
// capacity, replays every entry (exercises the head==tail-while-full case).
func testRecoverWALFull(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "full.blockvol")

	entrySize := uint64(walEntryHeaderSize + 4096)
	walSize := entrySize * 5 // room for exactly 5 entries

	v, err := CreateBlockVol(path, CreateOptions{
		VolumeSize: 1 * 1024 * 1024,
		BlockSize:  4096,
		WALSize:    walSize,
	})
	if err != nil {
		t.Fatalf("CreateBlockVol: %v", err)
	}

	// Write exactly 5 entries to fill WAL.
	for i := uint64(0); i < 5; i++ {
		if err := v.WriteLBA(i, makeBlock(byte('A'+i))); err != nil {
			t.Fatalf("WriteLBA(%d): %v", i, err)
		}
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL head/tail in the superblock, then crash.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path = simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	for i := uint64(0); i < 5; i++ {
		got, err := v2.ReadLBA(i, 4096)
		if err != nil {
			t.Fatalf("ReadLBA(%d): %v", i, err)
		}
		expected := makeBlock(byte('A' + i))
		if !bytes.Equal(got, expected) {
			t.Errorf("block %d: data mismatch after full WAL recovery", i)
		}
	}
}
||||
|
|
||||
|
// testRecoverBarrierOnly: a WAL containing only a barrier entry recovers
// cleanly and leaves the data region untouched.
func testRecoverBarrierOnly(t *testing.T) {
	v := createTestVol(t)

	// Write a barrier entry.
	lsn := v.nextLSN.Add(1) - 1
	entry := &WALEntry{
		LSN:  lsn,
		Type: EntryTypeBarrier,
		LBA:  0,
	}
	if _, err := v.wal.Append(entry); err != nil {
		t.Fatalf("Append barrier: %v", err)
	}

	if err := v.SyncCache(); err != nil {
		t.Fatalf("SyncCache: %v", err)
	}

	// Persist WAL state in the superblock, then crash.
	v.super.WALHead = v.wal.LogicalHead()
	v.super.WALTail = v.wal.LogicalTail()
	v.fd.Seek(0, 0)
	v.super.WriteTo(v.fd)
	v.fd.Sync()

	path := simulateCrash(v)

	v2, err := OpenBlockVol(path)
	if err != nil {
		t.Fatalf("OpenBlockVol: %v", err)
	}
	defer v2.Close()

	// No data changes from barrier.
	got, err := v2.ReadLBA(0, 4096)
	if err != nil {
		t.Fatalf("ReadLBA: %v", err)
	}
	if !bytes.Equal(got, make([]byte, 4096)) {
		t.Error("barrier-only WAL should leave data as zeros")
	}
}
||||
@ -0,0 +1,249 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
|
||||
|
"github.com/google/uuid" |
||||
|
) |
||||
|
|
||||
|
// On-disk format constants for the blockvol file header.
const (
	SuperblockSize = 4096   // fixed size of the header region at offset 0
	MagicSWBK      = "SWBK" // file magic identifying a blockvol file
	CurrentVersion = 1      // current on-disk format version
)
||||
|
|
||||
|
var (
	// ErrNotBlockVol is returned when the magic bytes do not match MagicSWBK.
	ErrNotBlockVol = errors.New("blockvol: not a blockvol file (bad magic)")
	// ErrUnsupportedVersion is returned for any version other than CurrentVersion.
	ErrUnsupportedVersion = errors.New("blockvol: unsupported version")
	// ErrInvalidVolumeSize is returned when the volume size is zero.
	ErrInvalidVolumeSize = errors.New("blockvol: volume size must be > 0")
)
||||
|
|
||||
|
// Superblock is the 4KB header at offset 0 of a blockvol file.
// It identifies the file format, stores volume geometry, and tracks WAL state.
//
// WALHead and WALTail are logical (monotonically increasing) counters;
// the physical WAL position is logical % WALSize (see WALWriter).
type Superblock struct {
	Magic            [4]byte
	Version          uint16
	Flags            uint16
	UUID             [16]byte
	VolumeSize       uint64  // logical size in bytes
	ExtentSize       uint32  // default 64KB
	BlockSize        uint32  // default 4KB (min I/O unit)
	WALOffset        uint64  // byte offset where WAL region starts
	WALSize          uint64  // WAL region size in bytes
	WALHead          uint64  // logical WAL write position (monotonically increasing)
	WALTail          uint64  // logical WAL flush position (monotonically increasing)
	WALCheckpointLSN uint64  // last LSN flushed to extent region
	Replication      [4]byte // replication policy string, "000" by default (NUL-padded)
	CreatedAt        uint64  // unix timestamp
	SnapshotCount    uint32
}

// superblockOnDisk is the fixed-size on-disk layout (binary.Write/Read target).
// Must be <= SuperblockSize bytes. Remaining space is zero-padded.
// Field order here defines the serialized byte order (little-endian).
type superblockOnDisk struct {
	Magic            [4]byte
	Version          uint16
	Flags            uint16
	UUID             [16]byte
	VolumeSize       uint64
	ExtentSize       uint32
	BlockSize        uint32
	WALOffset        uint64
	WALSize          uint64
	WALHead          uint64
	WALTail          uint64
	WALCheckpointLSN uint64
	Replication      [4]byte
	CreatedAt        uint64
	SnapshotCount    uint32
}
||||
|
|
||||
|
// NewSuperblock creates a superblock with defaults and a fresh UUID.
|
||||
|
func NewSuperblock(volumeSize uint64, opts CreateOptions) (Superblock, error) { |
||||
|
if volumeSize == 0 { |
||||
|
return Superblock{}, ErrInvalidVolumeSize |
||||
|
} |
||||
|
|
||||
|
extentSize := opts.ExtentSize |
||||
|
if extentSize == 0 { |
||||
|
extentSize = 64 * 1024 // 64KB
|
||||
|
} |
||||
|
blockSize := opts.BlockSize |
||||
|
if blockSize == 0 { |
||||
|
blockSize = 4096 // 4KB
|
||||
|
} |
||||
|
walSize := opts.WALSize |
||||
|
if walSize == 0 { |
||||
|
walSize = 64 * 1024 * 1024 // 64MB
|
||||
|
} |
||||
|
|
||||
|
var repl [4]byte |
||||
|
if opts.Replication != "" { |
||||
|
copy(repl[:], opts.Replication) |
||||
|
} else { |
||||
|
copy(repl[:], "000") |
||||
|
} |
||||
|
|
||||
|
id := uuid.New() |
||||
|
|
||||
|
sb := Superblock{ |
||||
|
Version: CurrentVersion, |
||||
|
VolumeSize: volumeSize, |
||||
|
ExtentSize: extentSize, |
||||
|
BlockSize: blockSize, |
||||
|
WALOffset: SuperblockSize, |
||||
|
WALSize: walSize, |
||||
|
} |
||||
|
copy(sb.Magic[:], MagicSWBK) |
||||
|
sb.UUID = id |
||||
|
sb.Replication = repl |
||||
|
|
||||
|
return sb, nil |
||||
|
} |
||||
|
|
||||
|
// WriteTo serializes the superblock to w as a 4096-byte block.
|
||||
|
func (sb *Superblock) WriteTo(w io.Writer) (int64, error) { |
||||
|
buf := make([]byte, SuperblockSize) |
||||
|
|
||||
|
d := superblockOnDisk{ |
||||
|
Magic: sb.Magic, |
||||
|
Version: sb.Version, |
||||
|
Flags: sb.Flags, |
||||
|
UUID: sb.UUID, |
||||
|
VolumeSize: sb.VolumeSize, |
||||
|
ExtentSize: sb.ExtentSize, |
||||
|
BlockSize: sb.BlockSize, |
||||
|
WALOffset: sb.WALOffset, |
||||
|
WALSize: sb.WALSize, |
||||
|
WALHead: sb.WALHead, |
||||
|
WALTail: sb.WALTail, |
||||
|
WALCheckpointLSN: sb.WALCheckpointLSN, |
||||
|
Replication: sb.Replication, |
||||
|
CreatedAt: sb.CreatedAt, |
||||
|
SnapshotCount: sb.SnapshotCount, |
||||
|
} |
||||
|
|
||||
|
// Encode into beginning of buf; rest stays zero (padding).
|
||||
|
endian := binary.LittleEndian |
||||
|
off := 0 |
||||
|
off += copy(buf[off:], d.Magic[:]) |
||||
|
endian.PutUint16(buf[off:], d.Version) |
||||
|
off += 2 |
||||
|
endian.PutUint16(buf[off:], d.Flags) |
||||
|
off += 2 |
||||
|
off += copy(buf[off:], d.UUID[:]) |
||||
|
endian.PutUint64(buf[off:], d.VolumeSize) |
||||
|
off += 8 |
||||
|
endian.PutUint32(buf[off:], d.ExtentSize) |
||||
|
off += 4 |
||||
|
endian.PutUint32(buf[off:], d.BlockSize) |
||||
|
off += 4 |
||||
|
endian.PutUint64(buf[off:], d.WALOffset) |
||||
|
off += 8 |
||||
|
endian.PutUint64(buf[off:], d.WALSize) |
||||
|
off += 8 |
||||
|
endian.PutUint64(buf[off:], d.WALHead) |
||||
|
off += 8 |
||||
|
endian.PutUint64(buf[off:], d.WALTail) |
||||
|
off += 8 |
||||
|
endian.PutUint64(buf[off:], d.WALCheckpointLSN) |
||||
|
off += 8 |
||||
|
off += copy(buf[off:], d.Replication[:]) |
||||
|
endian.PutUint64(buf[off:], d.CreatedAt) |
||||
|
off += 8 |
||||
|
endian.PutUint32(buf[off:], d.SnapshotCount) |
||||
|
|
||||
|
n, err := w.Write(buf) |
||||
|
return int64(n), err |
||||
|
} |
||||
|
|
||||
|
// ReadSuperblock reads and validates a superblock from r.
|
||||
|
func ReadSuperblock(r io.Reader) (Superblock, error) { |
||||
|
buf := make([]byte, SuperblockSize) |
||||
|
if _, err := io.ReadFull(r, buf); err != nil { |
||||
|
return Superblock{}, fmt.Errorf("blockvol: read superblock: %w", err) |
||||
|
} |
||||
|
|
||||
|
endian := binary.LittleEndian |
||||
|
var sb Superblock |
||||
|
off := 0 |
||||
|
copy(sb.Magic[:], buf[off:off+4]) |
||||
|
off += 4 |
||||
|
|
||||
|
if string(sb.Magic[:]) != MagicSWBK { |
||||
|
return Superblock{}, ErrNotBlockVol |
||||
|
} |
||||
|
|
||||
|
sb.Version = endian.Uint16(buf[off:]) |
||||
|
off += 2 |
||||
|
if sb.Version != CurrentVersion { |
||||
|
return Superblock{}, fmt.Errorf("%w: got %d, want %d", ErrUnsupportedVersion, sb.Version, CurrentVersion) |
||||
|
} |
||||
|
|
||||
|
sb.Flags = endian.Uint16(buf[off:]) |
||||
|
off += 2 |
||||
|
copy(sb.UUID[:], buf[off:off+16]) |
||||
|
off += 16 |
||||
|
sb.VolumeSize = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
|
||||
|
if sb.VolumeSize == 0 { |
||||
|
return Superblock{}, ErrInvalidVolumeSize |
||||
|
} |
||||
|
|
||||
|
sb.ExtentSize = endian.Uint32(buf[off:]) |
||||
|
off += 4 |
||||
|
sb.BlockSize = endian.Uint32(buf[off:]) |
||||
|
off += 4 |
||||
|
sb.WALOffset = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
sb.WALSize = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
sb.WALHead = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
sb.WALTail = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
sb.WALCheckpointLSN = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
copy(sb.Replication[:], buf[off:off+4]) |
||||
|
off += 4 |
||||
|
sb.CreatedAt = endian.Uint64(buf[off:]) |
||||
|
off += 8 |
||||
|
sb.SnapshotCount = endian.Uint32(buf[off:]) |
||||
|
|
||||
|
return sb, nil |
||||
|
} |
||||
|
|
||||
|
var ErrInvalidSuperblock = errors.New("blockvol: invalid superblock") |
||||
|
|
||||
|
// Validate checks that the superblock fields are internally consistent.
|
||||
|
func (sb *Superblock) Validate() error { |
||||
|
if string(sb.Magic[:]) != MagicSWBK { |
||||
|
return ErrNotBlockVol |
||||
|
} |
||||
|
if sb.Version != CurrentVersion { |
||||
|
return fmt.Errorf("%w: got %d", ErrUnsupportedVersion, sb.Version) |
||||
|
} |
||||
|
if sb.VolumeSize == 0 { |
||||
|
return ErrInvalidVolumeSize |
||||
|
} |
||||
|
if sb.BlockSize == 0 { |
||||
|
return fmt.Errorf("%w: BlockSize is 0", ErrInvalidSuperblock) |
||||
|
} |
||||
|
if sb.ExtentSize == 0 { |
||||
|
return fmt.Errorf("%w: ExtentSize is 0", ErrInvalidSuperblock) |
||||
|
} |
||||
|
if sb.WALSize == 0 { |
||||
|
return fmt.Errorf("%w: WALSize is 0", ErrInvalidSuperblock) |
||||
|
} |
||||
|
if sb.WALOffset != SuperblockSize { |
||||
|
return fmt.Errorf("%w: WALOffset=%d, expected %d", ErrInvalidSuperblock, sb.WALOffset, SuperblockSize) |
||||
|
} |
||||
|
if sb.VolumeSize%uint64(sb.BlockSize) != 0 { |
||||
|
return fmt.Errorf("%w: VolumeSize %d not aligned to BlockSize %d", ErrInvalidSuperblock, sb.VolumeSize, sb.BlockSize) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
@ -0,0 +1,146 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestSuperblock(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "superblock_roundtrip", run: testSuperblockRoundtrip}, |
||||
|
{name: "superblock_magic_check", run: testSuperblockMagicCheck}, |
||||
|
{name: "superblock_version_check", run: testSuperblockVersionCheck}, |
||||
|
{name: "superblock_zero_vol_size", run: testSuperblockZeroVolSize}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testSuperblockRoundtrip writes a fully-populated superblock and reads
// it back, verifying every persisted field survives the roundtrip and
// that exactly SuperblockSize bytes were written.
func testSuperblockRoundtrip(t *testing.T) {
	sb, err := NewSuperblock(100*1024*1024*1024, CreateOptions{
		ExtentSize:  64 * 1024,
		BlockSize:   4096,
		WALSize:     64 * 1024 * 1024,
		Replication: "010",
	})
	if err != nil {
		t.Fatalf("NewSuperblock: %v", err)
	}
	// Populate fields NewSuperblock leaves at zero so the roundtrip
	// exercises them too.
	sb.CreatedAt = 1700000000
	sb.SnapshotCount = 3
	sb.WALHead = 8192
	sb.WALCheckpointLSN = 42

	var buf bytes.Buffer
	n, err := sb.WriteTo(&buf)
	if err != nil {
		t.Fatalf("WriteTo: %v", err)
	}
	if n != SuperblockSize {
		t.Fatalf("WriteTo wrote %d bytes, want %d", n, SuperblockSize)
	}

	got, err := ReadSuperblock(&buf)
	if err != nil {
		t.Fatalf("ReadSuperblock: %v", err)
	}

	if string(got.Magic[:]) != MagicSWBK {
		t.Errorf("Magic = %q, want %q", got.Magic, MagicSWBK)
	}
	if got.Version != CurrentVersion {
		t.Errorf("Version = %d, want %d", got.Version, CurrentVersion)
	}
	if got.UUID != sb.UUID {
		t.Errorf("UUID mismatch")
	}
	if got.VolumeSize != sb.VolumeSize {
		t.Errorf("VolumeSize = %d, want %d", got.VolumeSize, sb.VolumeSize)
	}
	if got.ExtentSize != sb.ExtentSize {
		t.Errorf("ExtentSize = %d, want %d", got.ExtentSize, sb.ExtentSize)
	}
	if got.BlockSize != sb.BlockSize {
		t.Errorf("BlockSize = %d, want %d", got.BlockSize, sb.BlockSize)
	}
	if got.WALOffset != sb.WALOffset {
		t.Errorf("WALOffset = %d, want %d", got.WALOffset, sb.WALOffset)
	}
	if got.WALSize != sb.WALSize {
		t.Errorf("WALSize = %d, want %d", got.WALSize, sb.WALSize)
	}
	if got.WALHead != sb.WALHead {
		t.Errorf("WALHead = %d, want %d", got.WALHead, sb.WALHead)
	}
	if got.WALCheckpointLSN != sb.WALCheckpointLSN {
		t.Errorf("WALCheckpointLSN = %d, want %d", got.WALCheckpointLSN, sb.WALCheckpointLSN)
	}
	if got.Replication != sb.Replication {
		t.Errorf("Replication = %q, want %q", got.Replication, sb.Replication)
	}
	if got.CreatedAt != sb.CreatedAt {
		t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, sb.CreatedAt)
	}
	if got.SnapshotCount != sb.SnapshotCount {
		t.Errorf("SnapshotCount = %d, want %d", got.SnapshotCount, sb.SnapshotCount)
	}
}
||||
|
|
||||
|
func testSuperblockMagicCheck(t *testing.T) { |
||||
|
// Write a valid superblock, then corrupt the magic bytes.
|
||||
|
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
||||
|
var buf bytes.Buffer |
||||
|
sb.WriteTo(&buf) |
||||
|
|
||||
|
data := buf.Bytes() |
||||
|
copy(data[0:4], "XXXX") // corrupt magic
|
||||
|
|
||||
|
_, err := ReadSuperblock(bytes.NewReader(data)) |
||||
|
if !errors.Is(err, ErrNotBlockVol) { |
||||
|
t.Errorf("expected ErrNotBlockVol, got %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSuperblockVersionCheck(t *testing.T) { |
||||
|
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
||||
|
var buf bytes.Buffer |
||||
|
sb.WriteTo(&buf) |
||||
|
|
||||
|
data := buf.Bytes() |
||||
|
// Version is at offset 4, uint16 little-endian.
|
||||
|
binary.LittleEndian.PutUint16(data[4:], 99) |
||||
|
|
||||
|
_, err := ReadSuperblock(bytes.NewReader(data)) |
||||
|
if !errors.Is(err, ErrUnsupportedVersion) { |
||||
|
t.Errorf("expected ErrUnsupportedVersion, got %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testSuperblockZeroVolSize(t *testing.T) { |
||||
|
_, err := NewSuperblock(0, CreateOptions{}) |
||||
|
if !errors.Is(err, ErrInvalidVolumeSize) { |
||||
|
t.Errorf("NewSuperblock(0): expected ErrInvalidVolumeSize, got %v", err) |
||||
|
} |
||||
|
|
||||
|
// Also test reading a superblock with zero volume size.
|
||||
|
sb, _ := NewSuperblock(1*1024*1024*1024, CreateOptions{}) |
||||
|
var buf bytes.Buffer |
||||
|
sb.WriteTo(&buf) |
||||
|
|
||||
|
data := buf.Bytes() |
||||
|
// VolumeSize is at offset 4+2+2+16 = 24, uint64 little-endian.
|
||||
|
binary.LittleEndian.PutUint64(data[24:], 0) |
||||
|
|
||||
|
_, err = ReadSuperblock(bytes.NewReader(data)) |
||||
|
if !errors.Is(err, ErrInvalidVolumeSize) { |
||||
|
t.Errorf("ReadSuperblock(vol_size=0): expected ErrInvalidVolumeSize, got %v", err) |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,153 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"hash/crc32" |
||||
|
) |
||||
|
|
||||
|
const (
	// WAL entry types. PADDING entries fill the unusable tail of the
	// circular WAL region so real entries never wrap mid-record.
	EntryTypeWrite   = 0x01
	EntryTypeTrim    = 0x02
	EntryTypeBarrier = 0x03
	EntryTypePadding = 0xFF

	// walEntryHeaderSize is the fixed portion: LSN(8) + Epoch(8) + Type(1) +
	// Flags(1) + LBA(8) + Length(4) + CRC32(4) + EntrySize(4) = 38 bytes.
	walEntryHeaderSize = 38
)

var (
	// ErrCRCMismatch is returned when a decoded entry's stored CRC does
	// not match the CRC computed over its bytes.
	ErrCRCMismatch = errors.New("blockvol: CRC mismatch")
	// ErrInvalidEntry is returned for entries violating per-type invariants.
	ErrInvalidEntry = errors.New("blockvol: invalid WAL entry")
	// ErrEntryTruncated is returned when the buffer is too short for the
	// declared entry.
	ErrEntryTruncated = errors.New("blockvol: WAL entry truncated")
)

// WALEntry is a variable-size record in the WAL region.
type WALEntry struct {
	LSN       uint64
	Epoch     uint64 // writer's epoch (Phase 1: always 0)
	Type      uint8  // EntryTypeWrite, EntryTypeTrim, EntryTypeBarrier, EntryTypePadding
	Flags     uint8
	LBA       uint64 // in blocks
	Length    uint32 // data length in bytes (WRITE: data size, TRIM: trim size, BARRIER: 0)
	Data      []byte // present only for WRITE (and PADDING filler)
	CRC32     uint32 // covers LSN through Data
	EntrySize uint32 // total serialized size
}

// Encode serializes the entry into a byte slice, computing CRC32 over
// LSN..Data and appending the CRC32 and EntrySize footer. On success it
// also backfills e.CRC32 and e.EntrySize.
//
// Per-type invariants are validated first:
//   - WRITE:   Data non-empty and len(Data) == Length.
//   - TRIM:    Length is the trim extent in bytes; no Data payload.
//   - BARRIER: Length 0, no Data.
//   - PADDING: len(Data) == Length (filler bytes).
//
// Fix: unknown entry types (and PADDING with inconsistent Length) were
// previously accepted silently and would be written with a valid CRC;
// they are now rejected with ErrInvalidEntry.
func (e *WALEntry) Encode() ([]byte, error) {
	switch e.Type {
	case EntryTypeWrite:
		if len(e.Data) == 0 {
			return nil, fmt.Errorf("%w: WRITE entry with no data", ErrInvalidEntry)
		}
		if uint32(len(e.Data)) != e.Length {
			return nil, fmt.Errorf("%w: data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length)
		}
	case EntryTypeTrim:
		// TRIM carries Length (trim size in bytes) but no Data payload.
		if len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: TRIM entry must have no data payload", ErrInvalidEntry)
		}
	case EntryTypeBarrier:
		if e.Length != 0 || len(e.Data) != 0 {
			return nil, fmt.Errorf("%w: BARRIER entry must have no data", ErrInvalidEntry)
		}
	case EntryTypePadding:
		// Decoding treats Length as the padding payload size, so the
		// payload must match for the record to roundtrip.
		if uint32(len(e.Data)) != e.Length {
			return nil, fmt.Errorf("%w: PADDING data length %d != Length field %d", ErrInvalidEntry, len(e.Data), e.Length)
		}
	default:
		return nil, fmt.Errorf("%w: unknown entry type 0x%02x", ErrInvalidEntry, e.Type)
	}

	totalSize := uint32(walEntryHeaderSize + len(e.Data))
	buf := make([]byte, totalSize)

	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], e.LSN)
	off += 8
	le.PutUint64(buf[off:], e.Epoch)
	off += 8
	buf[off] = e.Type
	off++
	buf[off] = e.Flags
	off++
	le.PutUint64(buf[off:], e.LBA)
	off += 8
	le.PutUint32(buf[off:], e.Length)
	off += 4

	if len(e.Data) > 0 {
		copy(buf[off:], e.Data)
		off += len(e.Data)
	}

	// CRC covers everything from start through Data (offset 0 to off).
	checksum := crc32.ChecksumIEEE(buf[:off])
	le.PutUint32(buf[off:], checksum)
	off += 4
	le.PutUint32(buf[off:], totalSize)

	e.CRC32 = checksum
	e.EntrySize = totalSize
	return buf, nil
}

// DecodeWALEntry deserializes a WAL entry from buf, validating the CRC
// and the EntrySize footer. buf may extend beyond the entry; use the
// returned EntrySize to advance.
func DecodeWALEntry(buf []byte) (WALEntry, error) {
	if len(buf) < walEntryHeaderSize {
		return WALEntry{}, fmt.Errorf("%w: need %d bytes, have %d", ErrEntryTruncated, walEntryHeaderSize, len(buf))
	}

	le := binary.LittleEndian
	var e WALEntry
	off := 0
	e.LSN = le.Uint64(buf[off:])
	off += 8
	e.Epoch = le.Uint64(buf[off:])
	off += 8
	e.Type = buf[off]
	off++
	e.Flags = buf[off]
	off++
	e.LBA = le.Uint64(buf[off:])
	off += 8
	e.Length = le.Uint32(buf[off:])
	off += 4

	// For WRITE entries, Length is the data payload size.
	// For TRIM entries, Length is the trim extent in bytes (no data payload).
	// For BARRIER/PADDING, Length is 0 (or padding size).
	var dataLen int
	if e.Type == EntryTypeWrite || e.Type == EntryTypePadding {
		dataLen = int(e.Length)
	}

	dataEnd := off + dataLen
	if dataEnd+8 > len(buf) { // +8 for CRC32 + EntrySize
		return WALEntry{}, fmt.Errorf("%w: need %d bytes for data+footer, have %d", ErrEntryTruncated, dataEnd+8, len(buf))
	}

	if dataLen > 0 {
		e.Data = make([]byte, dataLen)
		copy(e.Data, buf[off:dataEnd])
	}
	off = dataEnd

	e.CRC32 = le.Uint32(buf[off:])
	off += 4
	e.EntrySize = le.Uint32(buf[off:])

	// Verify CRC: covers LSN through Data.
	expected := crc32.ChecksumIEEE(buf[:dataEnd])
	if e.CRC32 != expected {
		return WALEntry{}, fmt.Errorf("%w: stored=%08x computed=%08x", ErrCRCMismatch, e.CRC32, expected)
	}

	// Verify EntrySize matches actual layout.
	expectedSize := uint32(walEntryHeaderSize) + uint32(dataLen)
	if e.EntrySize != expectedSize {
		return WALEntry{}, fmt.Errorf("%w: EntrySize=%d, expected=%d", ErrInvalidEntry, e.EntrySize, expectedSize)
	}

	return e, nil
}
||||
@ -0,0 +1,284 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestWALEntry(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "wal_entry_roundtrip", run: testWALEntryRoundtrip}, |
||||
|
{name: "wal_entry_trim_no_data", run: testWALEntryTrimNoData}, |
||||
|
{name: "wal_entry_barrier", run: testWALEntryBarrier}, |
||||
|
{name: "wal_entry_crc_valid", run: testWALEntryCRCValid}, |
||||
|
{name: "wal_entry_crc_corrupt", run: testWALEntryCRCCorrupt}, |
||||
|
{name: "wal_entry_max_size", run: testWALEntryMaxSize}, |
||||
|
{name: "wal_entry_zero_length", run: testWALEntryZeroLength}, |
||||
|
{name: "wal_entry_trim_rejects_data", run: testWALEntryTrimRejectsData}, |
||||
|
{name: "wal_entry_barrier_rejects_data", run: testWALEntryBarrierRejectsData}, |
||||
|
{name: "wal_entry_bad_entry_size", run: testWALEntryBadEntrySize}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALEntryRoundtrip(t *testing.T) { |
||||
|
data := make([]byte, 4096) |
||||
|
for i := range data { |
||||
|
data[i] = byte(i % 251) // deterministic pattern
|
||||
|
} |
||||
|
|
||||
|
e := WALEntry{ |
||||
|
LSN: 1, |
||||
|
Epoch: 0, |
||||
|
Type: EntryTypeWrite, |
||||
|
Flags: 0, |
||||
|
LBA: 100, |
||||
|
Length: uint32(len(data)), |
||||
|
Data: data, |
||||
|
} |
||||
|
|
||||
|
buf, err := e.Encode() |
||||
|
if err != nil { |
||||
|
t.Fatalf("Encode: %v", err) |
||||
|
} |
||||
|
|
||||
|
got, err := DecodeWALEntry(buf) |
||||
|
if err != nil { |
||||
|
t.Fatalf("DecodeWALEntry: %v", err) |
||||
|
} |
||||
|
|
||||
|
if got.LSN != e.LSN { |
||||
|
t.Errorf("LSN = %d, want %d", got.LSN, e.LSN) |
||||
|
} |
||||
|
if got.Epoch != e.Epoch { |
||||
|
t.Errorf("Epoch = %d, want %d", got.Epoch, e.Epoch) |
||||
|
} |
||||
|
if got.Type != e.Type { |
||||
|
t.Errorf("Type = %d, want %d", got.Type, e.Type) |
||||
|
} |
||||
|
if got.LBA != e.LBA { |
||||
|
t.Errorf("LBA = %d, want %d", got.LBA, e.LBA) |
||||
|
} |
||||
|
if got.Length != e.Length { |
||||
|
t.Errorf("Length = %d, want %d", got.Length, e.Length) |
||||
|
} |
||||
|
if !bytes.Equal(got.Data, e.Data) { |
||||
|
t.Errorf("Data mismatch") |
||||
|
} |
||||
|
if got.EntrySize != uint32(walEntryHeaderSize+len(data)) { |
||||
|
t.Errorf("EntrySize = %d, want %d", got.EntrySize, walEntryHeaderSize+len(data)) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// testWALEntryTrimNoData verifies that a TRIM entry roundtrips with its
// Length (the trim extent in bytes) intact and no Data payload, and
// that its serialized size is header-only.
func testWALEntryTrimNoData(t *testing.T) {
	// TRIM carries Length (trim size in bytes) but no Data payload.
	e := WALEntry{
		LSN:    5,
		Epoch:  0,
		Type:   EntryTypeTrim,
		LBA:    200,
		Length: 4096,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}

	if got.Type != EntryTypeTrim {
		t.Errorf("Type = %d, want %d", got.Type, EntryTypeTrim)
	}
	if got.Length != 4096 {
		t.Errorf("Length = %d, want 4096", got.Length)
	}
	if len(got.Data) != 0 {
		t.Errorf("Data should be empty, got %d bytes", len(got.Data))
	}
	// TRIM with Length but no Data: EntrySize = header only.
	if got.EntrySize != uint32(walEntryHeaderSize) {
		t.Errorf("EntrySize = %d, want %d (header only)", got.EntrySize, walEntryHeaderSize)
	}
}

// testWALEntryBarrier verifies that a BARRIER entry roundtrips as a
// header-only record with no data payload.
func testWALEntryBarrier(t *testing.T) {
	e := WALEntry{
		LSN:   10,
		Epoch: 0,
		Type:  EntryTypeBarrier,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}

	if got.Type != EntryTypeBarrier {
		t.Errorf("Type = %d, want %d", got.Type, EntryTypeBarrier)
	}
	if len(got.Data) != 0 {
		t.Errorf("Barrier should have no data, got %d bytes", len(got.Data))
	}
}
||||
|
|
||||
|
// testWALEntryCRCValid verifies that an untouched encoded entry passes
// the CRC check on decode.
func testWALEntryCRCValid(t *testing.T) {
	data := []byte("hello blockvol WAL")
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(data)),
		Data:   data,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// Decode should succeed with valid CRC.
	_, err = DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("valid entry should decode without error, got: %v", err)
	}
}

// testWALEntryCRCCorrupt flips one bit in the data payload and verifies
// decode fails with ErrCRCMismatch.
func testWALEntryCRCCorrupt(t *testing.T) {
	data := []byte("hello blockvol WAL")
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(data)),
		Data:   data,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// Flip one bit in the data region.
	// walEntryHeaderSize-8 == 30, the first payload byte (the 8-byte
	// CRC32+EntrySize footer is counted inside walEntryHeaderSize).
	buf[walEntryHeaderSize-8] ^= 0x01 // flip bit in data area (before CRC/EntrySize footer)

	_, err = DecodeWALEntry(buf)
	if !errors.Is(err, ErrCRCMismatch) {
		t.Errorf("expected ErrCRCMismatch, got %v", err)
	}
}
||||
|
|
||||
|
// testWALEntryMaxSize roundtrips a 64KB payload (one full extent, the
// largest WRITE exercised by these tests).
func testWALEntryMaxSize(t *testing.T) {
	data := make([]byte, 64*1024) // 64KB
	for i := range data {
		data[i] = byte(i % 256)
	}

	e := WALEntry{
		LSN:    99,
		Epoch:  0,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(data)),
		Data:   data,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode 64KB entry: %v", err)
	}

	got, err := DecodeWALEntry(buf)
	if err != nil {
		t.Fatalf("DecodeWALEntry: %v", err)
	}

	if !bytes.Equal(got.Data, data) {
		t.Errorf("64KB data mismatch after roundtrip")
	}
}

// testWALEntryZeroLength verifies a WRITE entry with no payload is
// rejected at encode time with ErrInvalidEntry.
func testWALEntryZeroLength(t *testing.T) {
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: 0,
		Data:   nil,
	}

	_, err := e.Encode()
	if !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("WRITE with zero data: expected ErrInvalidEntry, got %v", err)
	}
}
||||
|
|
||||
|
// testWALEntryTrimRejectsData verifies a TRIM entry with an attached
// Data payload is rejected at encode time.
func testWALEntryTrimRejectsData(t *testing.T) {
	// TRIM allows Length (trim extent) but rejects Data payload.
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeTrim,
		LBA:    0,
		Length: 4096,
		Data:   []byte("bad!"),
	}
	_, err := e.Encode()
	if !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("TRIM with data payload: expected ErrInvalidEntry, got %v", err)
	}
}

// testWALEntryBarrierRejectsData verifies a BARRIER entry carrying a
// Length/Data payload is rejected at encode time.
func testWALEntryBarrierRejectsData(t *testing.T) {
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeBarrier,
		Length: 1,
		Data:   []byte("x"),
	}
	_, err := e.Encode()
	if !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("BARRIER with data: expected ErrInvalidEntry, got %v", err)
	}
}
||||
|
|
||||
|
// testWALEntryBadEntrySize corrupts the EntrySize footer (which the CRC
// does not cover) and verifies decode rejects the inconsistency with
// ErrInvalidEntry rather than a CRC error.
func testWALEntryBadEntrySize(t *testing.T) {
	data := []byte("test data for entry size validation")
	e := WALEntry{
		LSN:    1,
		Type:   EntryTypeWrite,
		LBA:    0,
		Length: uint32(len(data)),
		Data:   data,
	}

	buf, err := e.Encode()
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// Corrupt EntrySize (last 4 bytes of buf).
	le := binary.LittleEndian
	le.PutUint32(buf[len(buf)-4:], 9999)

	_, err = DecodeWALEntry(buf)
	if !errors.Is(err, ErrInvalidEntry) {
		t.Errorf("bad EntrySize: expected ErrInvalidEntry, got %v", err)
	}
}
||||
@ -0,0 +1,192 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import ( |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"hash/crc32" |
||||
|
"os" |
||||
|
"sync" |
||||
|
) |
||||
|
|
||||
|
var (
	// ErrWALFull is returned when an entry (plus any wrap padding) does
	// not fit in the free space of the circular WAL region.
	ErrWALFull = errors.New("blockvol: WAL region full")
)

// WALWriter appends entries to the circular WAL region of a blockvol file.
//
// It uses logical (monotonically increasing) head and tail counters to track
// used space. Physical position = logical % walSize. This eliminates the
// classic circular buffer ambiguity where head==tail could mean empty or full.
// Used space = logicalHead - logicalTail. Free space = walSize - used.
type WALWriter struct {
	mu          sync.Mutex // guards all position fields below
	fd          *os.File
	walOffset   uint64 // absolute file offset where WAL region starts
	walSize     uint64 // size of the WAL region in bytes
	logicalHead uint64 // monotonically increasing write position
	logicalTail uint64 // monotonically increasing flush position
}

// NewWALWriter creates a WAL writer for the given file.
// head and tail are logical (monotonically increasing) positions, as
// persisted in the superblock's WALHead/WALTail fields — they are
// assigned directly to logicalHead/logicalTail below.
// For a fresh WAL, both are 0 (where physical and logical coincide).
func NewWALWriter(fd *os.File, walOffset, walSize, head, tail uint64) *WALWriter {
	return &WALWriter{
		fd:          fd,
		walOffset:   walOffset,
		walSize:     walSize,
		logicalHead: head, // on fresh WAL, physical == logical (both start at 0)
		logicalTail: tail,
	}
}

// physicalPos converts a logical position to a physical WAL offset.
func (w *WALWriter) physicalPos(logical uint64) uint64 {
	return logical % w.walSize
}

// used returns the number of bytes occupied in the WAL.
// Caller must hold w.mu.
func (w *WALWriter) used() uint64 {
	return w.logicalHead - w.logicalTail
}
||||
|
|
||||
|
// Append writes a serialized WAL entry to the circular WAL region.
// Returns the physical WAL-relative offset where the entry was written.
// If the entry doesn't fit in the remaining space before the region end,
// a padding entry is written and the real entry starts at physical offset 0.
//
// Note the check ordering: space for padding+entry is verified BEFORE
// the padding is written, so a full WAL never consumes space on failure.
// Append does not fsync; callers use Sync (or group commit) for that.
func (w *WALWriter) Append(entry *WALEntry) (walRelOffset uint64, err error) {
	buf, err := entry.Encode()
	if err != nil {
		return 0, fmt.Errorf("WALWriter.Append: encode: %w", err)
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	entryLen := uint64(len(buf))
	if entryLen > w.walSize {
		return 0, fmt.Errorf("%w: entry size %d exceeds WAL size %d", ErrWALFull, entryLen, w.walSize)
	}

	physHead := w.physicalPos(w.logicalHead)
	remaining := w.walSize - physHead

	if remaining < entryLen {
		// Not enough room at end of region -- write padding and wrap.
		// Padding consumes 'remaining' bytes logically.
		if w.used()+remaining+entryLen > w.walSize {
			return 0, ErrWALFull
		}
		if err := w.writePadding(remaining, physHead); err != nil {
			return 0, fmt.Errorf("WALWriter.Append: padding: %w", err)
		}
		w.logicalHead += remaining
		physHead = 0
	}

	// Check if there's enough free space for the entry.
	if w.used()+entryLen > w.walSize {
		return 0, ErrWALFull
	}

	absOffset := int64(w.walOffset + physHead)
	if _, err := w.fd.WriteAt(buf, absOffset); err != nil {
		return 0, fmt.Errorf("WALWriter.Append: pwrite at offset %d: %w", absOffset, err)
	}

	writeOffset := physHead
	w.logicalHead += entryLen
	return writeOffset, nil
}
||||
|
|
||||
|
// writePadding writes a padding entry at the given physical position.
// The padding record uses the same wire layout as a normal entry
// (hand-encoded here rather than via Encode) with Type=EntryTypePadding
// and Length = size - walEntryHeaderSize, so recovery can CRC-check and
// skip it like any other record. If size is smaller than a header, the
// region is simply zeroed — recovery is presumably expected to treat
// such a sub-header tail specially (TODO confirm against the recovery
// path, which is not visible here).
// Caller must hold w.mu.
func (w *WALWriter) writePadding(size uint64, physPos uint64) error {
	if size < walEntryHeaderSize {
		// Too small for a proper entry header -- zero it out.
		buf := make([]byte, size)
		absOffset := int64(w.walOffset + physPos)
		_, err := w.fd.WriteAt(buf, absOffset)
		return err
	}

	buf := make([]byte, size)
	le := binary.LittleEndian
	off := 0
	le.PutUint64(buf[off:], 0) // LSN=0
	off += 8
	le.PutUint64(buf[off:], 0) // Epoch=0
	off += 8
	buf[off] = EntryTypePadding
	off++
	buf[off] = 0 // Flags
	off++
	le.PutUint64(buf[off:], 0) // LBA=0
	off += 8
	paddingDataLen := uint32(size) - uint32(walEntryHeaderSize)
	le.PutUint32(buf[off:], paddingDataLen)
	off += 4
	dataEnd := off + int(paddingDataLen)

	// CRC covers header-through-data, matching WALEntry.Encode.
	crc := crc32.ChecksumIEEE(buf[:dataEnd])
	le.PutUint32(buf[dataEnd:], crc)
	le.PutUint32(buf[dataEnd+4:], uint32(size))

	absOffset := int64(w.walOffset + physPos)
	_, err := w.fd.WriteAt(buf, absOffset)
	return err
}
||||
|
|
||||
|
// AdvanceTail moves the tail forward, freeing WAL space.
|
||||
|
// Called by the flusher after entries have been written to the extent region.
|
||||
|
// newTail is a physical position; it is converted to a logical advance.
|
||||
|
func (w *WALWriter) AdvanceTail(newTail uint64) { |
||||
|
w.mu.Lock() |
||||
|
physTail := w.physicalPos(w.logicalTail) |
||||
|
var advance uint64 |
||||
|
if newTail >= physTail { |
||||
|
advance = newTail - physTail |
||||
|
} else { |
||||
|
// Tail wrapped around.
|
||||
|
advance = w.walSize - physTail + newTail |
||||
|
} |
||||
|
w.logicalTail += advance |
||||
|
w.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// Head returns the current physical head position (relative to WAL start).
|
||||
|
func (w *WALWriter) Head() uint64 { |
||||
|
w.mu.Lock() |
||||
|
h := w.physicalPos(w.logicalHead) |
||||
|
w.mu.Unlock() |
||||
|
return h |
||||
|
} |
||||
|
|
||||
|
// Tail returns the current physical tail position (relative to WAL start).
|
||||
|
func (w *WALWriter) Tail() uint64 { |
||||
|
w.mu.Lock() |
||||
|
t := w.physicalPos(w.logicalTail) |
||||
|
w.mu.Unlock() |
||||
|
return t |
||||
|
} |
||||
|
|
||||
|
// LogicalHead returns the logical (monotonically increasing) head position.
|
||||
|
func (w *WALWriter) LogicalHead() uint64 { |
||||
|
w.mu.Lock() |
||||
|
h := w.logicalHead |
||||
|
w.mu.Unlock() |
||||
|
return h |
||||
|
} |
||||
|
|
||||
|
// LogicalTail returns the logical (monotonically increasing) tail position.
|
||||
|
func (w *WALWriter) LogicalTail() uint64 { |
||||
|
w.mu.Lock() |
||||
|
t := w.logicalTail |
||||
|
w.mu.Unlock() |
||||
|
return t |
||||
|
} |
||||
|
|
||||
|
// Sync fsyncs the underlying file descriptor, making all previously
// appended WAL entries durable. It takes no lock: fd is never reassigned
// and os.File.Sync is safe for concurrent use.
func (w *WALWriter) Sync() error {
	return w.fd.Sync()
}
||||
@ -0,0 +1,244 @@ |
|||||
|
package blockvol |
||||
|
|
||||
|
import (
	"errors"
	"os"
	"path/filepath"
	"testing"
)
||||
|
|
||||
|
func TestWALWriter(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
name string |
||||
|
run func(t *testing.T) |
||||
|
}{ |
||||
|
{name: "wal_writer_append_read_back", run: testWALWriterAppendReadBack}, |
||||
|
{name: "wal_writer_multiple_entries", run: testWALWriterMultipleEntries}, |
||||
|
{name: "wal_writer_wrap_around", run: testWALWriterWrapAround}, |
||||
|
{name: "wal_writer_full", run: testWALWriterFull}, |
||||
|
{name: "wal_writer_advance_tail_frees_space", run: testWALWriterAdvanceTailFreesSpace}, |
||||
|
{name: "wal_writer_fill_no_flusher", run: testWALWriterFillNoFlusher}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
t.Run(tt.name, func(t *testing.T) { |
||||
|
tt.run(t) |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// createTestWAL creates a temp file laid out as walOffset bytes of
// superblock region followed by walSize bytes of WAL region. The returned
// func closes the file; the temp directory itself is cleaned up by
// t.TempDir.
func createTestWAL(t *testing.T, walOffset, walSize uint64) (*os.File, func()) {
	t.Helper()

	name := filepath.Join(t.TempDir(), "test.blockvol")
	f, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		t.Fatalf("create temp file: %v", err)
	}

	// Pre-size the file so WriteAt/ReadAt anywhere in the WAL region
	// never run past EOF.
	if err := f.Truncate(int64(walOffset + walSize)); err != nil {
		f.Close()
		t.Fatalf("truncate: %v", err)
	}
	return f, func() { f.Close() }
}
||||
|
|
||||
|
func testWALWriterAppendReadBack(t *testing.T) { |
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
walSize := uint64(64 * 1024) // 64KB WAL
|
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
data := []byte("hello WAL writer test!") |
||||
|
// Pad to block size for a realistic entry.
|
||||
|
padded := make([]byte, 4096) |
||||
|
copy(padded, data) |
||||
|
|
||||
|
entry := &WALEntry{ |
||||
|
LSN: 1, |
||||
|
Type: EntryTypeWrite, |
||||
|
LBA: 0, |
||||
|
Length: uint32(len(padded)), |
||||
|
Data: padded, |
||||
|
} |
||||
|
|
||||
|
off, err := w.Append(entry) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Append: %v", err) |
||||
|
} |
||||
|
if off != 0 { |
||||
|
t.Errorf("first entry offset = %d, want 0", off) |
||||
|
} |
||||
|
|
||||
|
// Read back from file and decode.
|
||||
|
buf := make([]byte, entry.EntrySize) |
||||
|
if _, err := fd.ReadAt(buf, int64(walOffset+off)); err != nil { |
||||
|
t.Fatalf("ReadAt: %v", err) |
||||
|
} |
||||
|
decoded, err := DecodeWALEntry(buf) |
||||
|
if err != nil { |
||||
|
t.Fatalf("DecodeWALEntry: %v", err) |
||||
|
} |
||||
|
if decoded.LSN != 1 { |
||||
|
t.Errorf("decoded LSN = %d, want 1", decoded.LSN) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALWriterMultipleEntries(t *testing.T) { |
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
walSize := uint64(64 * 1024) |
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
for i := uint64(1); i <= 5; i++ { |
||||
|
entry := &WALEntry{ |
||||
|
LSN: i, |
||||
|
Type: EntryTypeWrite, |
||||
|
LBA: i * 10, |
||||
|
Length: 4096, |
||||
|
Data: make([]byte, 4096), |
||||
|
} |
||||
|
_, err := w.Append(entry) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Append entry %d: %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
expectedHead := uint64(5 * (walEntryHeaderSize + 4096)) |
||||
|
if w.Head() != expectedHead { |
||||
|
t.Errorf("head = %d, want %d", w.Head(), expectedHead) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALWriterWrapAround(t *testing.T) { |
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
// Small WAL: 2.5x entry size. After writing 2 entries (head at 2*entrySize),
|
||||
|
// remaining (0.5*entrySize) is too small for entry 3, so it wraps to offset 0.
|
||||
|
// Tail must be advanced past both entries so there's free space after wrap.
|
||||
|
entrySize := uint64(walEntryHeaderSize + 4096) |
||||
|
walSize := entrySize*2 + entrySize/2 |
||||
|
|
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
// Write 2 entries to fill most of the WAL.
|
||||
|
for i := uint64(1); i <= 2; i++ { |
||||
|
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)} |
||||
|
if _, err := w.Append(entry); err != nil { |
||||
|
t.Fatalf("Append entry %d: %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Advance tail past both entries (simulates flusher flushed them).
|
||||
|
w.AdvanceTail(entrySize * 2) |
||||
|
|
||||
|
// Write a 3rd entry -- should wrap around.
|
||||
|
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 3, Length: 4096, Data: make([]byte, 4096)} |
||||
|
off, err := w.Append(entry3) |
||||
|
if err != nil { |
||||
|
t.Fatalf("Append entry 3 (wrap): %v", err) |
||||
|
} |
||||
|
|
||||
|
// After wrap, entry should be written at offset 0.
|
||||
|
if off != 0 { |
||||
|
t.Errorf("wrapped entry offset = %d, want 0", off) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALWriterFull(t *testing.T) { |
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
entrySize := uint64(walEntryHeaderSize + 4096) |
||||
|
walSize := entrySize * 2 // fits exactly 2 entries
|
||||
|
|
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
// Fill the WAL with 2 entries (exact fit).
|
||||
|
for i := uint64(1); i <= 2; i++ { |
||||
|
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)} |
||||
|
if _, err := w.Append(entry); err != nil { |
||||
|
t.Fatalf("Append entry %d: %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Third entry should fail -- tail hasn't moved, so no free space.
|
||||
|
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)} |
||||
|
_, err := w.Append(entry3) |
||||
|
if err == nil { |
||||
|
t.Fatal("expected ErrWALFull when WAL is full") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALWriterAdvanceTailFreesSpace(t *testing.T) { |
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
entrySize := uint64(walEntryHeaderSize + 4096) |
||||
|
walSize := entrySize * 2 |
||||
|
|
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
// Fill the WAL with 2 entries.
|
||||
|
for i := uint64(1); i <= 2; i++ { |
||||
|
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i - 1, Length: 4096, Data: make([]byte, 4096)} |
||||
|
if _, err := w.Append(entry); err != nil { |
||||
|
t.Fatalf("Append entry %d: %v", i, err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// WAL is full. Third entry should fail.
|
||||
|
entry3 := &WALEntry{LSN: 3, Type: EntryTypeWrite, LBA: 2, Length: 4096, Data: make([]byte, 4096)} |
||||
|
if _, err := w.Append(entry3); err == nil { |
||||
|
t.Fatal("expected ErrWALFull before AdvanceTail") |
||||
|
} |
||||
|
|
||||
|
// Advance tail to free space for 1 entry.
|
||||
|
w.AdvanceTail(entrySize) |
||||
|
|
||||
|
// Now entry 3 should succeed.
|
||||
|
if _, err := w.Append(entry3); err != nil { |
||||
|
t.Fatalf("Append after AdvanceTail: %v", err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testWALWriterFillNoFlusher(t *testing.T) { |
||||
|
// QA-001 regression: fill WAL without flusher (tail stays at 0).
|
||||
|
// After wrap, head=0 and tail=0 must NOT be treated as "empty".
|
||||
|
walOffset := uint64(SuperblockSize) |
||||
|
entrySize := uint64(walEntryHeaderSize + 4096) |
||||
|
walSize := entrySize * 10 // room for ~10 entries
|
||||
|
|
||||
|
fd, cleanup := createTestWAL(t, walOffset, walSize) |
||||
|
defer cleanup() |
||||
|
|
||||
|
w := NewWALWriter(fd, walOffset, walSize, 0, 0) |
||||
|
|
||||
|
// Fill the WAL completely -- tail never moves (no flusher).
|
||||
|
written := 0 |
||||
|
for i := uint64(1); ; i++ { |
||||
|
entry := &WALEntry{LSN: i, Type: EntryTypeWrite, LBA: i, Length: 4096, Data: make([]byte, 4096)} |
||||
|
_, err := w.Append(entry) |
||||
|
if err != nil { |
||||
|
// Should eventually get ErrWALFull, NOT wrap and overwrite.
|
||||
|
break |
||||
|
} |
||||
|
written++ |
||||
|
if written > 20 { |
||||
|
t.Fatalf("wrote %d entries to a 10-entry WAL without ErrWALFull -- wrap overwrote live entries (QA-001)", written) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if written == 0 { |
||||
|
t.Fatal("should have written at least 1 entry") |
||||
|
} |
||||
|
t.Logf("correctly wrote %d entries then got ErrWALFull", written) |
||||
|
} |
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue