diff --git a/sw-block/engine/replication/budget.go b/sw-block/engine/replication/budget.go new file mode 100644 index 000000000..cf62d9ea7 --- /dev/null +++ b/sw-block/engine/replication/budget.go @@ -0,0 +1,46 @@ +package replication + +// CatchUpBudget defines the bounded resource contract for a catch-up session. +// When any limit is exceeded, the session escalates to NeedsRebuild. +// A zero value for any field means "no limit" for that dimension. +// +// Note: the frozen catch-up target is on Session.FrozenTargetLSN, not here. +// FrozenTargetLSN is set unconditionally by BeginCatchUp and enforced by +// RecordCatchUpProgress regardless of budget presence. +type CatchUpBudget struct { + MaxDurationTicks uint64 // hard time limit + MaxEntries uint64 // max WAL entries to replay + ProgressDeadlineTicks uint64 // stall detection window +} + +// BudgetCheck tracks runtime budget consumption. +type BudgetCheck struct { + StartTick uint64 + EntriesReplayed uint64 + LastProgressTick uint64 +} + +// BudgetViolation identifies which budget limit was exceeded. +type BudgetViolation string + +const ( + BudgetOK BudgetViolation = "" + BudgetDurationExceeded BudgetViolation = "duration_exceeded" + BudgetEntriesExceeded BudgetViolation = "entries_exceeded" + BudgetProgressStalled BudgetViolation = "progress_stalled" +) + +// Check evaluates the budget against the current tick. 
+func (b *CatchUpBudget) Check(tracker BudgetCheck, currentTick uint64) BudgetViolation { + if b.MaxDurationTicks > 0 && currentTick-tracker.StartTick > b.MaxDurationTicks { + return BudgetDurationExceeded + } + if b.MaxEntries > 0 && tracker.EntriesReplayed > b.MaxEntries { + return BudgetEntriesExceeded + } + if b.ProgressDeadlineTicks > 0 && tracker.LastProgressTick > 0 && + currentTick-tracker.LastProgressTick > b.ProgressDeadlineTicks { + return BudgetProgressStalled + } + return BudgetOK +} diff --git a/sw-block/engine/replication/doc.go b/sw-block/engine/replication/doc.go new file mode 100644 index 000000000..86a14c942 --- /dev/null +++ b/sw-block/engine/replication/doc.go @@ -0,0 +1,24 @@ +// Package replication implements V2 per-replica sender/session ownership. +// +// This is the real V2 engine core, promoted from the prototype at +// sw-block/prototype/enginev2/. It preserves all accepted invariants: +// +// - One stable Sender per replica, identified by ReplicaID +// - One active Session per replica per epoch +// - Session identity fencing: stale sessionID rejected at every execution API +// - Endpoint change invalidates active session +// - Epoch bump invalidates all stale-epoch sessions +// - Catch-up is bounded (frozen target, budget enforcement) +// - Rebuild is a separate, exclusive sender-owned execution path +// - Completion requires convergence (catch-up) or ReadyToComplete (rebuild) +// +// File layout (Slice 1): +// +// types.go — Endpoint, ReplicaState, SessionKind, SessionPhase +// sender.go — Sender: per-replica owner with execution APIs +// session.go — Session: recovery lifecycle with FSM phases +// registry.go — Registry: sender group with reconcile + assignment intent +// budget.go — CatchUpBudget: bounded catch-up enforcement +// rebuild.go — RebuildState: rebuild execution FSM +// outcome.go — HandshakeResult, RecoveryOutcome classification +package replication diff --git a/sw-block/engine/replication/go.mod 
b/sw-block/engine/replication/go.mod new file mode 100644 index 000000000..788b4617c --- /dev/null +++ b/sw-block/engine/replication/go.mod @@ -0,0 +1,3 @@ +module github.com/seaweedfs/seaweedfs/sw-block/engine/replication + +go 1.23.0 diff --git a/sw-block/engine/replication/outcome.go b/sw-block/engine/replication/outcome.go new file mode 100644 index 000000000..182910a57 --- /dev/null +++ b/sw-block/engine/replication/outcome.go @@ -0,0 +1,30 @@ +package replication + +// HandshakeResult captures what the reconnect handshake reveals about +// a replica's state relative to the primary's lineage-safe boundary. +type HandshakeResult struct { + ReplicaFlushedLSN uint64 // highest LSN durably persisted on replica + CommittedLSN uint64 // lineage-safe recovery target + RetentionStartLSN uint64 // oldest LSN still available in primary WAL +} + +// RecoveryOutcome classifies the gap between replica and primary. +type RecoveryOutcome string + +const ( + OutcomeZeroGap RecoveryOutcome = "zero_gap" + OutcomeCatchUp RecoveryOutcome = "catchup" + OutcomeNeedsRebuild RecoveryOutcome = "needs_rebuild" +) + +// ClassifyRecoveryOutcome determines the recovery path from handshake data. +// Zero-gap requires exact equality (FlushedLSN == CommittedLSN). 
+func ClassifyRecoveryOutcome(result HandshakeResult) RecoveryOutcome { + if result.ReplicaFlushedLSN == result.CommittedLSN { + return OutcomeZeroGap + } + if result.RetentionStartLSN == 0 || result.ReplicaFlushedLSN+1 >= result.RetentionStartLSN { + return OutcomeCatchUp + } + return OutcomeNeedsRebuild +} diff --git a/sw-block/engine/replication/ownership_test.go b/sw-block/engine/replication/ownership_test.go new file mode 100644 index 000000000..bee6e3d81 --- /dev/null +++ b/sw-block/engine/replication/ownership_test.go @@ -0,0 +1,312 @@ +package replication + +import "testing" + +// ============================================================ +// Phase 05 Slice 1: Engine ownership/fencing tests +// Mapped to V2 acceptance criteria and boundary cases. +// ============================================================ + +// --- Changed-address invalidation (A10) --- + +func TestEngine_ChangedAddress_InvalidatesSession(t *testing.T) { + r := NewRegistry() + r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", CtrlAddr: "r1:9334", Version: 1}, + }, + Epoch: 1, + RecoveryTargets: map[string]SessionKind{"r1:9333": SessionCatchUp}, + }) + + s := r.Sender("r1:9333") + sess := s.Session() + s.BeginConnect(sess.ID) + + // Address changes mid-recovery. 
+ r.Reconcile(map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", CtrlAddr: "r1:9445", Version: 2}, + }, 1) + + if sess.Active() { + t.Fatal("session should be invalidated by address change") + } + if r.Sender("r1:9333") != s { + t.Fatal("sender identity should be preserved") + } + if s.State != StateDisconnected { + t.Fatalf("state=%s, want disconnected", s.State) + } +} + +func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) { + r := NewRegistry() + r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", Version: 1}, + }, + Epoch: 1, + RecoveryTargets: map[string]SessionKind{"r1:9333": SessionCatchUp}, + }) + + s := r.Sender("r1:9333") + oldSess := s.Session() + s.BeginConnect(oldSess.ID) + + // Address change + new assignment. + r.Reconcile(map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", Version: 2}, + }, 1) + result := r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{"r1:9333": {DataAddr: "r1:9333", Version: 2}}, + Epoch: 1, + RecoveryTargets: map[string]SessionKind{"r1:9333": SessionCatchUp}, + }) + + if len(result.SessionsCreated) != 1 { + t.Fatalf("should create new session: %v", result) + } + newSess := s.Session() + if newSess.ID == oldSess.ID { + t.Fatal("new session should have different ID") + } +} + +// --- Stale-session rejection (A3) --- + +func TestEngine_StaleSessionID_RejectedAtAllAPIs(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sess, _ := s.AttachSession(1, SessionCatchUp) + staleID := sess.ID + + s.UpdateEpoch(2) + s.AttachSession(2, SessionCatchUp) + + if err := s.BeginConnect(staleID); err == nil { + t.Fatal("stale ID: BeginConnect should reject") + } + if err := s.RecordHandshake(staleID, 0, 10); err == nil { + t.Fatal("stale ID: RecordHandshake should reject") + } + if err := s.BeginCatchUp(staleID); err == nil { + t.Fatal("stale ID: BeginCatchUp should reject") + } + if err := 
s.RecordCatchUpProgress(staleID, 5); err == nil { + t.Fatal("stale ID: RecordCatchUpProgress should reject") + } + if s.CompleteSessionByID(staleID) { + t.Fatal("stale ID: CompleteSessionByID should reject") + } +} + +func TestEngine_StaleCompletion_AfterSupersede(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sess1, _ := s.AttachSession(1, SessionCatchUp) + id1 := sess1.ID + + s.UpdateEpoch(2) + sess2, _ := s.AttachSession(2, SessionCatchUp) + + // Old session completion rejected. + if s.CompleteSessionByID(id1) { + t.Fatal("stale completion must be rejected") + } + // New session still active. + if !sess2.Active() { + t.Fatal("new session should be active") + } + // Sender not moved to InSync. + if s.State == StateInSync { + t.Fatal("sender should not be InSync from stale completion") + } +} + +// --- Epoch-bump invalidation (A3) --- + +func TestEngine_EpochBump_InvalidatesAllSessions(t *testing.T) { + r := NewRegistry() + r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", Version: 1}, + "r2:9333": {DataAddr: "r2:9333", Version: 1}, + }, + Epoch: 1, + RecoveryTargets: map[string]SessionKind{ + "r1:9333": SessionCatchUp, + "r2:9333": SessionCatchUp, + }, + }) + + s1 := r.Sender("r1:9333") + s2 := r.Sender("r2:9333") + sess1 := s1.Session() + sess2 := s2.Session() + + count := r.InvalidateEpoch(2) + if count != 2 { + t.Fatalf("should invalidate 2, got %d", count) + } + if sess1.Active() || sess2.Active() { + t.Fatal("both sessions should be invalidated") + } +} + +func TestEngine_EpochBump_StaleAssignment_Rejected(t *testing.T) { + r := NewRegistry() + r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", Version: 1}, + }, + Epoch: 2, + }) + + result := r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{"r1:9333": {DataAddr: "r1:9333", Version: 1}}, + Epoch: 1, + RecoveryTargets: 
map[string]SessionKind{"r1:9333": SessionCatchUp}, + }) + + if len(result.SessionsFailed) != 1 { + t.Fatalf("stale epoch should fail: %v", result) + } +} + +// --- Rebuild exclusivity --- + +func TestEngine_Rebuild_CatchUpAPIs_Rejected(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sess, _ := s.AttachSession(1, SessionRebuild) + s.BeginConnect(sess.ID) + s.RecordHandshake(sess.ID, 0, 100) + + if err := s.BeginCatchUp(sess.ID); err == nil { + t.Fatal("rebuild: BeginCatchUp should be rejected") + } + if err := s.RecordCatchUpProgress(sess.ID, 50); err == nil { + t.Fatal("rebuild: RecordCatchUpProgress should be rejected") + } + if s.CompleteSessionByID(sess.ID) { + t.Fatal("rebuild: catch-up completion should be rejected") + } +} + +func TestEngine_Rebuild_FullLifecycle(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sess, _ := s.AttachSession(1, SessionRebuild) + + s.BeginConnect(sess.ID) + s.RecordHandshake(sess.ID, 0, 100) + s.SelectRebuildSource(sess.ID, 50, true, 100) + s.BeginRebuildTransfer(sess.ID) + s.RecordRebuildTransferProgress(sess.ID, 50) + s.BeginRebuildTailReplay(sess.ID) + s.RecordRebuildTailProgress(sess.ID, 100) + + if err := s.CompleteRebuild(sess.ID); err != nil { + t.Fatalf("rebuild completion: %v", err) + } + if s.State != StateInSync { + t.Fatalf("state=%s, want in_sync", s.State) + } +} + +// --- Bounded catch-up --- + +func TestEngine_FrozenTarget_RejectsChase(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sess, _ := s.AttachSession(1, SessionCatchUp) + + s.BeginConnect(sess.ID) + s.RecordHandshake(sess.ID, 0, 50) + s.BeginCatchUp(sess.ID) + + if err := s.RecordCatchUpProgress(sess.ID, 51); err == nil { + t.Fatal("progress beyond frozen target should be rejected") + } +} + +func TestEngine_BudgetViolation_Escalates(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + 
sess, _ := s.AttachSession(1, SessionCatchUp) + sess.Budget = &CatchUpBudget{MaxDurationTicks: 5} + + s.BeginConnect(sess.ID) + s.RecordHandshake(sess.ID, 0, 100) + s.BeginCatchUp(sess.ID, 0) + s.RecordCatchUpProgress(sess.ID, 10) + + v, _ := s.CheckBudget(sess.ID, 10) + if v != BudgetDurationExceeded { + t.Fatalf("budget=%s", v) + } + if s.State != StateNeedsRebuild { + t.Fatalf("state=%s", s.State) + } +} + +// --- E2E: 3 replicas, 3 outcomes --- + +func TestEngine_E2E_ThreeReplicas_ThreeOutcomes(t *testing.T) { + r := NewRegistry() + r.ApplyAssignment(AssignmentIntent{ + Endpoints: map[string]Endpoint{ + "r1:9333": {DataAddr: "r1:9333", Version: 1}, + "r2:9333": {DataAddr: "r2:9333", Version: 1}, + "r3:9333": {DataAddr: "r3:9333", Version: 1}, + }, + Epoch: 1, + RecoveryTargets: map[string]SessionKind{ + "r1:9333": SessionCatchUp, + "r2:9333": SessionCatchUp, + "r3:9333": SessionCatchUp, + }, + }) + + // r1: zero-gap. + r1 := r.Sender("r1:9333") + s1 := r1.Session() + r1.BeginConnect(s1.ID) + o1, _ := r1.RecordHandshakeWithOutcome(s1.ID, HandshakeResult{ + ReplicaFlushedLSN: 100, CommittedLSN: 100, RetentionStartLSN: 50, + }) + if o1 != OutcomeZeroGap { + t.Fatalf("r1: %s", o1) + } + r1.CompleteSessionByID(s1.ID) + + // r2: catch-up. + r2 := r.Sender("r2:9333") + s2 := r2.Session() + r2.BeginConnect(s2.ID) + o2, _ := r2.RecordHandshakeWithOutcome(s2.ID, HandshakeResult{ + ReplicaFlushedLSN: 70, CommittedLSN: 100, RetentionStartLSN: 50, + }) + if o2 != OutcomeCatchUp { + t.Fatalf("r2: %s", o2) + } + r2.BeginCatchUp(s2.ID) + r2.RecordCatchUpProgress(s2.ID, 100) + r2.CompleteSessionByID(s2.ID) + + // r3: needs rebuild. + r3 := r.Sender("r3:9333") + s3 := r3.Session() + r3.BeginConnect(s3.ID) + o3, _ := r3.RecordHandshakeWithOutcome(s3.ID, HandshakeResult{ + ReplicaFlushedLSN: 10, CommittedLSN: 100, RetentionStartLSN: 50, + }) + if o3 != OutcomeNeedsRebuild { + t.Fatalf("r3: %s", o3) + } + + // Final states. 
+ if r1.State != StateInSync || r2.State != StateInSync { + t.Fatalf("r1=%s r2=%s", r1.State, r2.State) + } + if r3.State != StateNeedsRebuild { + t.Fatalf("r3=%s", r3.State) + } + if r.InSyncCount() != 2 { + t.Fatalf("in_sync=%d", r.InSyncCount()) + } +} diff --git a/sw-block/engine/replication/rebuild.go b/sw-block/engine/replication/rebuild.go new file mode 100644 index 000000000..abee0cb37 --- /dev/null +++ b/sw-block/engine/replication/rebuild.go @@ -0,0 +1,127 @@ +package replication + +import "fmt" + +// RebuildSource identifies the recovery base. +type RebuildSource string + +const ( + RebuildSnapshotTail RebuildSource = "snapshot_tail" + RebuildFullBase RebuildSource = "full_base" +) + +// RebuildPhase tracks rebuild execution progress. +type RebuildPhase string + +const ( + RebuildPhaseInit RebuildPhase = "init" + RebuildPhaseSourceSelect RebuildPhase = "source_select" + RebuildPhaseTransfer RebuildPhase = "transfer" + RebuildPhaseTailReplay RebuildPhase = "tail_replay" + RebuildPhaseCompleted RebuildPhase = "completed" + RebuildPhaseAborted RebuildPhase = "aborted" +) + +// RebuildState tracks rebuild execution. Owned by Session. +type RebuildState struct { + Source RebuildSource + Phase RebuildPhase + AbortReason string + SnapshotLSN uint64 + SnapshotValid bool + TransferredTo uint64 + TailStartLSN uint64 + TailTargetLSN uint64 + TailReplayedTo uint64 +} + +// NewRebuildState creates a rebuild state in init phase. +func NewRebuildState() *RebuildState { + return &RebuildState{Phase: RebuildPhaseInit} +} + +// SelectSource chooses rebuild source based on snapshot availability. 
+func (rs *RebuildState) SelectSource(snapshotLSN uint64, snapshotValid bool, committedLSN uint64) error { + if rs.Phase != RebuildPhaseInit { + return fmt.Errorf("rebuild: source select requires init phase, got %s", rs.Phase) + } + rs.SnapshotLSN = snapshotLSN + rs.SnapshotValid = snapshotValid + rs.Phase = RebuildPhaseSourceSelect + if snapshotValid && snapshotLSN > 0 { + rs.Source = RebuildSnapshotTail + rs.TailStartLSN = snapshotLSN + rs.TailTargetLSN = committedLSN + } else { + rs.Source = RebuildFullBase + rs.TailTargetLSN = committedLSN + } + return nil +} + +func (rs *RebuildState) BeginTransfer() error { + if rs.Phase != RebuildPhaseSourceSelect { + return fmt.Errorf("rebuild: transfer requires source_select, got %s", rs.Phase) + } + rs.Phase = RebuildPhaseTransfer + return nil +} + +func (rs *RebuildState) RecordTransferProgress(transferredTo uint64) error { + if rs.Phase != RebuildPhaseTransfer { + return fmt.Errorf("rebuild: progress requires transfer, got %s", rs.Phase) + } + if transferredTo <= rs.TransferredTo { + return fmt.Errorf("rebuild: transfer regression") + } + rs.TransferredTo = transferredTo + return nil +} + +func (rs *RebuildState) BeginTailReplay() error { + if rs.Phase != RebuildPhaseTransfer { + return fmt.Errorf("rebuild: tail replay requires transfer, got %s", rs.Phase) + } + if rs.Source != RebuildSnapshotTail { + return fmt.Errorf("rebuild: tail replay only for snapshot_tail") + } + rs.Phase = RebuildPhaseTailReplay + return nil +} + +func (rs *RebuildState) RecordTailReplayProgress(replayedTo uint64) error { + if rs.Phase != RebuildPhaseTailReplay { + return fmt.Errorf("rebuild: tail progress requires tail_replay, got %s", rs.Phase) + } + if replayedTo <= rs.TailReplayedTo { + return fmt.Errorf("rebuild: tail regression") + } + rs.TailReplayedTo = replayedTo + return nil +} + +func (rs *RebuildState) ReadyToComplete() bool { + switch rs.Source { + case RebuildSnapshotTail: + return rs.Phase == RebuildPhaseTailReplay && 
rs.TailReplayedTo >= rs.TailTargetLSN + case RebuildFullBase: + return rs.Phase == RebuildPhaseTransfer && rs.TransferredTo >= rs.TailTargetLSN + } + return false +} + +func (rs *RebuildState) Complete() error { + if !rs.ReadyToComplete() { + return fmt.Errorf("rebuild: not ready (source=%s phase=%s)", rs.Source, rs.Phase) + } + rs.Phase = RebuildPhaseCompleted + return nil +} + +func (rs *RebuildState) Abort(reason string) { + if rs.Phase == RebuildPhaseCompleted || rs.Phase == RebuildPhaseAborted { + return + } + rs.Phase = RebuildPhaseAborted + rs.AbortReason = reason +} diff --git a/sw-block/engine/replication/registry.go b/sw-block/engine/replication/registry.go new file mode 100644 index 000000000..27b0a0bf8 --- /dev/null +++ b/sw-block/engine/replication/registry.go @@ -0,0 +1,160 @@ +package replication + +import ( + "sort" + "sync" +) + +// AssignmentIntent represents a coordinator-driven assignment update. +type AssignmentIntent struct { + Endpoints map[string]Endpoint + Epoch uint64 + RecoveryTargets map[string]SessionKind +} + +// AssignmentResult records the outcome of applying an assignment. +type AssignmentResult struct { + Added []string + Removed []string + SessionsCreated []string + SessionsSuperseded []string + SessionsFailed []string +} + +// Registry manages per-replica Senders with identity-preserving reconciliation. +type Registry struct { + mu sync.RWMutex + senders map[string]*Sender +} + +// NewRegistry creates an empty Registry. +func NewRegistry() *Registry { + return &Registry{senders: map[string]*Sender{}} +} + +// Reconcile diffs current senders against new endpoints. 
+func (r *Registry) Reconcile(endpoints map[string]Endpoint, epoch uint64) (added, removed []string) { + r.mu.Lock() + defer r.mu.Unlock() + + for id, s := range r.senders { + if _, keep := endpoints[id]; !keep { + s.Stop() + delete(r.senders, id) + removed = append(removed, id) + } + } + for id, ep := range endpoints { + if existing, ok := r.senders[id]; ok { + existing.UpdateEndpoint(ep) + existing.UpdateEpoch(epoch) + } else { + r.senders[id] = NewSender(id, ep, epoch) + added = append(added, id) + } + } + sort.Strings(added) + sort.Strings(removed) + return +} + +// ApplyAssignment reconciles topology and creates recovery sessions. +func (r *Registry) ApplyAssignment(intent AssignmentIntent) AssignmentResult { + var result AssignmentResult + result.Added, result.Removed = r.Reconcile(intent.Endpoints, intent.Epoch) + + if intent.RecoveryTargets == nil { + return result + } + + r.mu.RLock() + defer r.mu.RUnlock() + for replicaID, kind := range intent.RecoveryTargets { + sender, ok := r.senders[replicaID] + if !ok { + result.SessionsFailed = append(result.SessionsFailed, replicaID) + continue + } + if intent.Epoch < sender.Epoch { + result.SessionsFailed = append(result.SessionsFailed, replicaID) + continue + } + _, err := sender.AttachSession(intent.Epoch, kind) + if err != nil { + sess := sender.SupersedeSession(kind, "assignment_intent") + if sess != nil { + result.SessionsSuperseded = append(result.SessionsSuperseded, replicaID) + } else { + result.SessionsFailed = append(result.SessionsFailed, replicaID) + } + continue + } + result.SessionsCreated = append(result.SessionsCreated, replicaID) + } + return result +} + +// Sender returns the sender for a ReplicaID. +func (r *Registry) Sender(replicaID string) *Sender { + r.mu.RLock() + defer r.mu.RUnlock() + return r.senders[replicaID] +} + +// All returns all senders in deterministic order. 
+func (r *Registry) All() []*Sender { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]*Sender, 0, len(r.senders)) + for _, s := range r.senders { + out = append(out, s) + } + sort.Slice(out, func(i, j int) bool { + return out[i].ReplicaID < out[j].ReplicaID + }) + return out +} + +// Len returns the sender count. +func (r *Registry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.senders) +} + +// StopAll stops all senders. +func (r *Registry) StopAll() { + r.mu.Lock() + defer r.mu.Unlock() + for _, s := range r.senders { + s.Stop() + } +} + +// InSyncCount returns the number of InSync senders. +func (r *Registry) InSyncCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + count := 0 + for _, s := range r.senders { + if s.State == StateInSync { + count++ + } + } + return count +} + +// InvalidateEpoch invalidates all stale-epoch sessions. +func (r *Registry) InvalidateEpoch(currentEpoch uint64) int { + r.mu.RLock() + defer r.mu.RUnlock() + count := 0 + for _, s := range r.senders { + sess := s.Session() + if sess != nil && sess.Epoch < currentEpoch && sess.Active() { + s.InvalidateSession("epoch_bump", StateDisconnected) + count++ + } + } + return count +} diff --git a/sw-block/engine/replication/sender.go b/sw-block/engine/replication/sender.go new file mode 100644 index 000000000..236c2c6a0 --- /dev/null +++ b/sw-block/engine/replication/sender.go @@ -0,0 +1,430 @@ +package replication + +import ( + "fmt" + "sync" +) + +// Sender owns the replication channel to one replica. It is the authority +// boundary for all execution operations — every API validates the session +// ID before mutating state. +type Sender struct { + mu sync.Mutex + + ReplicaID string + Endpoint Endpoint + Epoch uint64 + State ReplicaState + + session *Session + stopped bool +} + +// NewSender creates a sender for a replica. 
+func NewSender(replicaID string, endpoint Endpoint, epoch uint64) *Sender { + return &Sender{ + ReplicaID: replicaID, + Endpoint: endpoint, + Epoch: epoch, + State: StateDisconnected, + } +} + +// UpdateEpoch advances the sender's epoch. Invalidates stale sessions. +func (s *Sender) UpdateEpoch(epoch uint64) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped || epoch <= s.Epoch { + return + } + oldEpoch := s.Epoch + s.Epoch = epoch + if s.session != nil && s.session.Epoch < epoch { + s.session.invalidate(fmt.Sprintf("epoch_advanced_%d_to_%d", oldEpoch, epoch)) + s.session = nil + s.State = StateDisconnected + } +} + +// UpdateEndpoint updates the target address. Invalidates session on change. +func (s *Sender) UpdateEndpoint(ep Endpoint) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return + } + if s.Endpoint.Changed(ep) && s.session != nil { + s.session.invalidate("endpoint_changed") + s.session = nil + s.State = StateDisconnected + } + s.Endpoint = ep +} + +// AttachSession creates a new recovery session. Epoch must match sender epoch. +func (s *Sender) AttachSession(epoch uint64, kind SessionKind) (*Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return nil, fmt.Errorf("sender stopped") + } + if epoch != s.Epoch { + return nil, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.Epoch, epoch) + } + if s.session != nil && s.session.Active() { + return nil, fmt.Errorf("session already active (id=%d)", s.session.ID) + } + sess := newSession(s.ReplicaID, epoch, kind) + s.session = sess + return sess, nil +} + +// SupersedeSession invalidates current session and attaches new at sender epoch. 
+func (s *Sender) SupersedeSession(kind SessionKind, reason string) *Session { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return nil + } + if s.session != nil { + s.session.invalidate(reason) + } + sess := newSession(s.ReplicaID, s.Epoch, kind) + s.session = sess + return sess +} + +// Session returns the current session, or nil. +func (s *Sender) Session() *Session { + s.mu.Lock() + defer s.mu.Unlock() + return s.session +} + +// Stopped returns true if the sender has been stopped. +func (s *Sender) Stopped() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.stopped +} + +// Stop shuts down the sender. +func (s *Sender) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return + } + s.stopped = true + if s.session != nil { + s.session.invalidate("sender_stopped") + s.session = nil + } +} + +// InvalidateSession invalidates current session with target state. +func (s *Sender) InvalidateSession(reason string, targetState ReplicaState) { + s.mu.Lock() + defer s.mu.Unlock() + if s.session != nil { + s.session.invalidate(reason) + s.session = nil + } + s.State = targetState +} + +// === Catch-up execution APIs === + +// BeginConnect transitions init → connecting. +func (s *Sender) BeginConnect(sessionID uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if !s.session.Advance(PhaseConnecting) { + return fmt.Errorf("cannot begin connect: phase=%s", s.session.Phase) + } + s.State = StateConnecting + return nil +} + +// RecordHandshake records handshake result and sets catch-up range. 
+func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if targetLSN < startLSN { + return fmt.Errorf("invalid range: target=%d < start=%d", targetLSN, startLSN) + } + if !s.session.Advance(PhaseHandshake) { + return fmt.Errorf("cannot record handshake: phase=%s", s.session.Phase) + } + s.session.SetRange(startLSN, targetLSN) + return nil +} + +// RecordHandshakeWithOutcome records handshake and classifies the recovery outcome. +func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeResult) (RecoveryOutcome, error) { + outcome := ClassifyRecoveryOutcome(result) + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return outcome, err + } + if s.session.Phase != PhaseConnecting { + return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.Phase) + } + if outcome == OutcomeNeedsRebuild { + s.session.invalidate("gap_exceeds_retention") + s.session = nil + s.State = StateNeedsRebuild + return outcome, nil + } + if !s.session.Advance(PhaseHandshake) { + return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.Phase) + } + switch outcome { + case OutcomeZeroGap: + s.session.SetRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN) + case OutcomeCatchUp: + if result.ReplicaFlushedLSN > result.CommittedLSN { + s.session.TruncateRequired = true + s.session.TruncateToLSN = result.CommittedLSN + s.session.SetRange(result.CommittedLSN, result.CommittedLSN) + } else { + s.session.SetRange(result.ReplicaFlushedLSN, result.CommittedLSN) + } + } + return outcome, nil +} + +// BeginCatchUp transitions to catch-up phase. Rejects rebuild sessions. +// Freezes the target unconditionally. 
+func (s *Sender) BeginCatchUp(sessionID uint64, startTick ...uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Kind == SessionRebuild { + return fmt.Errorf("rebuild sessions must use rebuild APIs") + } + if !s.session.Advance(PhaseCatchUp) { + return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.Phase) + } + s.State = StateCatchingUp + s.session.FrozenTargetLSN = s.session.TargetLSN + if len(startTick) > 0 { + s.session.Tracker.StartTick = startTick[0] + s.session.Tracker.LastProgressTick = startTick[0] + } + return nil +} + +// RecordCatchUpProgress records catch-up progress. Rejects rebuild sessions. +// Entry counting uses LSN delta. Tick is required when ProgressDeadlineTicks > 0. +func (s *Sender) RecordCatchUpProgress(sessionID uint64, recoveredTo uint64, tick ...uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Kind == SessionRebuild { + return fmt.Errorf("rebuild sessions must use rebuild APIs") + } + if s.session.Phase != PhaseCatchUp { + return fmt.Errorf("progress requires catchup phase, got %s", s.session.Phase) + } + if recoveredTo <= s.session.RecoveredTo { + return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.RecoveredTo) + } + if s.session.FrozenTargetLSN > 0 && recoveredTo > s.session.FrozenTargetLSN { + return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.FrozenTargetLSN) + } + if s.session.Budget != nil && s.session.Budget.ProgressDeadlineTicks > 0 && len(tick) == 0 { + return fmt.Errorf("tick required when ProgressDeadlineTicks > 0") + } + delta := recoveredTo - s.session.RecoveredTo + s.session.Tracker.EntriesReplayed += delta + s.session.UpdateProgress(recoveredTo) + if len(tick) > 0 { + s.session.Tracker.LastProgressTick = tick[0] + } + return nil +} + +// RecordTruncation confirms divergent tail cleanup. 
+func (s *Sender) RecordTruncation(sessionID uint64, truncatedToLSN uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if !s.session.TruncateRequired { + return fmt.Errorf("truncation not required") + } + if truncatedToLSN != s.session.TruncateToLSN { + return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.TruncateToLSN, truncatedToLSN) + } + s.session.TruncateRecorded = true + return nil +} + +// CompleteSessionByID completes catch-up sessions. Rejects rebuild sessions. +func (s *Sender) CompleteSessionByID(sessionID uint64) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.checkAuthority(sessionID) != nil { + return false + } + sess := s.session + if sess.Kind == SessionRebuild { + return false + } + if sess.TruncateRequired && !sess.TruncateRecorded { + return false + } + switch sess.Phase { + case PhaseCatchUp: + if !sess.Converged() { + return false + } + case PhaseHandshake: + if sess.TargetLSN != sess.StartLSN { + return false + } + default: + return false + } + sess.complete() + s.session = nil + s.State = StateInSync + return true +} + +// CheckBudget evaluates catch-up budget. Auto-escalates on violation. +func (s *Sender) CheckBudget(sessionID uint64, currentTick uint64) (BudgetViolation, error) { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return BudgetOK, err + } + if s.session.Budget == nil { + return BudgetOK, nil + } + v := s.session.Budget.Check(s.session.Tracker, currentTick) + if v != BudgetOK { + s.session.invalidate(fmt.Sprintf("budget_%s", v)) + s.session = nil + s.State = StateNeedsRebuild + } + return v, nil +} + +// === Rebuild execution APIs === + +// SelectRebuildSource chooses rebuild source. Requires PhaseHandshake. 
+func (s *Sender) SelectRebuildSource(sessionID uint64, snapshotLSN uint64, snapshotValid bool, committedLSN uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Kind != SessionRebuild { + return fmt.Errorf("not a rebuild session") + } + if s.session.Phase != PhaseHandshake { + return fmt.Errorf("requires PhaseHandshake, got %s", s.session.Phase) + } + if s.session.Rebuild == nil { + return fmt.Errorf("rebuild state not initialized") + } + return s.session.Rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN) +} + +func (s *Sender) BeginRebuildTransfer(sessionID uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Rebuild == nil { + return fmt.Errorf("no rebuild state") + } + return s.session.Rebuild.BeginTransfer() +} + +func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Rebuild == nil { + return fmt.Errorf("no rebuild state") + } + return s.session.Rebuild.RecordTransferProgress(transferredTo) +} + +func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Rebuild == nil { + return fmt.Errorf("no rebuild state") + } + return s.session.Rebuild.BeginTailReplay() +} + +func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.checkAuthority(sessionID); err != nil { + return err + } + if s.session.Rebuild == nil { + return fmt.Errorf("no rebuild state") + } + return s.session.Rebuild.RecordTailReplayProgress(replayedTo) +} + +// CompleteRebuild completes a rebuild session. Requires ReadyToComplete. 
func (s *Sender) CompleteRebuild(sessionID uint64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if err := s.checkAuthority(sessionID); err != nil {
		return err
	}
	// Only rebuild sessions carry a RebuildState (see newSession), so this
	// nil check also rejects non-rebuild sessions.
	if s.session.Rebuild == nil {
		return fmt.Errorf("no rebuild state")
	}
	// RebuildState.Complete enforces its own readiness gate; the session is
	// only marked done once it accepts.
	if err := s.session.Rebuild.Complete(); err != nil {
		return err
	}
	s.session.complete()
	s.session = nil
	s.State = StateInSync
	return nil
}

// checkAuthority validates session ownership. It rejects calls when the
// sender is stopped, when no session is attached, when the caller's
// sessionID does not match the active session (stale-session fencing),
// or when the session is no longer active. Callers must hold s.mu.
func (s *Sender) checkAuthority(sessionID uint64) error {
	if s.stopped {
		return fmt.Errorf("sender stopped")
	}
	if s.session == nil {
		return fmt.Errorf("no active session")
	}
	if s.session.ID != sessionID {
		return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.ID, sessionID)
	}
	if !s.session.Active() {
		return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.Phase)
	}
	return nil
}
diff --git a/sw-block/engine/replication/session.go b/sw-block/engine/replication/session.go
new file mode 100644
index 000000000..0d5d2989d
--- /dev/null
+++ b/sw-block/engine/replication/session.go
@@ -0,0 +1,103 @@
package replication

import "sync/atomic"

// sessionIDCounter generates globally unique session IDs.
var sessionIDCounter atomic.Uint64

// Session represents one recovery attempt for a specific replica at a
// specific epoch. It is owned by a Sender and gated by session ID at
// every execution step.
//
// Lifecycle:
//   - Created via Sender.AttachSession or Sender.SupersedeSession
//   - Advanced through phases: init → connecting → handshake → catchup → completed
//   - Invalidated by: epoch bump, endpoint change, sender stop, timeout
//   - Stale sessions (wrong ID) are rejected at every execution API
type Session struct {
	ID               uint64
	ReplicaID        string
	Epoch            uint64
	Kind             SessionKind
	Phase            SessionPhase
	InvalidateReason string // set by invalidate; first reason wins

	// Progress tracking.
	StartLSN        uint64 // gap start (exclusive)
	TargetLSN       uint64 // gap end (inclusive)
	FrozenTargetLSN uint64 // frozen at BeginCatchUp — catch-up will not chase beyond
	RecoveredTo     uint64 // highest LSN recovered so far

	// Truncation. When TruncateRequired is set, completion is blocked until
	// RecordTruncation confirms the exact TruncateToLSN and sets
	// TruncateRecorded.
	TruncateRequired bool
	TruncateToLSN    uint64
	TruncateRecorded bool

	// Budget (nil = no enforcement).
	Budget  *CatchUpBudget
	Tracker BudgetCheck

	// Rebuild state (non-nil when Kind == SessionRebuild).
	Rebuild *RebuildState
}

// newSession allocates a session in PhaseInit with a globally unique ID.
// Rebuild sessions get a fresh RebuildState; all other kinds leave
// Rebuild nil.
func newSession(replicaID string, epoch uint64, kind SessionKind) *Session {
	s := &Session{
		ID:        sessionIDCounter.Add(1),
		ReplicaID: replicaID,
		Epoch:     epoch,
		Kind:      kind,
		Phase:     PhaseInit,
	}
	if kind == SessionRebuild {
		s.Rebuild = NewRebuildState()
	}
	return s
}

// Active returns true if the session is not completed or invalidated.
func (s *Session) Active() bool {
	return s.Phase != PhaseCompleted && s.Phase != PhaseInvalidated
}

// Advance moves to the next phase. Returns false if the transition is invalid.
func (s *Session) Advance(phase SessionPhase) bool {
	if !s.Active() {
		return false
	}
	// validTransitions (types.go) defines the legal phase graph.
	if !validTransitions[s.Phase][phase] {
		return false
	}
	s.Phase = phase
	return true
}

// SetRange sets the recovery LSN range.
func (s *Session) SetRange(start, target uint64) {
	s.StartLSN = start
	s.TargetLSN = target
}

// UpdateProgress records catch-up progress (monotonic).
// Regressions (recoveredTo below the current high-water mark) are ignored.
func (s *Session) UpdateProgress(recoveredTo uint64) {
	if recoveredTo > s.RecoveredTo {
		s.RecoveredTo = recoveredTo
	}
}

// Converged returns true if recovery reached the target.
func (s *Session) Converged() bool {
	// TargetLSN == 0 means no range was set; such a session never converges
	// here (CompleteSessionByID's handshake path handles the empty-gap
	// case separately).
	return s.TargetLSN > 0 && s.RecoveredTo >= s.TargetLSN
}

// complete marks the session finished. It bypasses the validTransitions
// check in Advance; callers (the Sender completion paths) are responsible
// for verifying completion preconditions first.
func (s *Session) complete() {
	s.Phase = PhaseCompleted
}

// invalidate terminates the session, recording why. Idempotent: a session
// that is already completed or invalidated is left untouched, so the first
// recorded reason wins.
func (s *Session) invalidate(reason string) {
	if !s.Active() {
		return
	}
	s.Phase = PhaseInvalidated
	s.InvalidateReason = reason
}
diff --git a/sw-block/engine/replication/types.go b/sw-block/engine/replication/types.go
new file mode 100644
index 000000000..33db9785e
--- /dev/null
+++ b/sw-block/engine/replication/types.go
@@ -0,0 +1,59 @@
package replication

// Endpoint represents a replica's network identity. Version is bumped on
// address change; the Sender uses version comparison (not string comparison
// alone) to detect endpoint changes.
type Endpoint struct {
	DataAddr string
	CtrlAddr string
	Version  uint64
}

// Changed reports whether ep differs from other in any address or version field.
func (ep Endpoint) Changed(other Endpoint) bool {
	return ep.DataAddr != other.DataAddr ||
		ep.CtrlAddr != other.CtrlAddr ||
		ep.Version != other.Version
}

// ReplicaState tracks the per-replica replication state machine.
type ReplicaState string

const (
	StateDisconnected ReplicaState = "disconnected"
	StateConnecting   ReplicaState = "connecting"
	StateCatchingUp   ReplicaState = "catching_up"
	StateInSync       ReplicaState = "in_sync"
	StateDegraded     ReplicaState = "degraded"
	StateNeedsRebuild ReplicaState = "needs_rebuild"
)

// SessionKind identifies how the recovery session was created.
type SessionKind string

const (
	SessionBootstrap SessionKind = "bootstrap"
	SessionCatchUp   SessionKind = "catchup"
	SessionRebuild   SessionKind = "rebuild"
	SessionReassign  SessionKind = "reassign"
)

// SessionPhase tracks progress within a recovery session.
type SessionPhase string

const (
	PhaseInit        SessionPhase = "init"
	PhaseConnecting  SessionPhase = "connecting"
	PhaseHandshake   SessionPhase = "handshake"
	PhaseCatchUp     SessionPhase = "catchup"
	PhaseCompleted   SessionPhase = "completed"
	PhaseInvalidated SessionPhase = "invalidated"
)

// validTransitions defines the allowed phase transitions.
// PhaseCompleted and PhaseInvalidated are terminal: they have no entry here,
// and indexing the resulting nil inner map yields false, so Session.Advance
// rejects every transition out of them (its Active() guard also covers this).
var validTransitions = map[SessionPhase]map[SessionPhase]bool{
	PhaseInit:       {PhaseConnecting: true, PhaseInvalidated: true},
	PhaseConnecting: {PhaseHandshake: true, PhaseInvalidated: true},
	PhaseHandshake:  {PhaseCatchUp: true, PhaseCompleted: true, PhaseInvalidated: true},
	PhaseCatchUp:    {PhaseCompleted: true, PhaseInvalidated: true},
}