From bb24b4b039e9120848f55ed2ca957f3a23cc03ef Mon Sep 17 00:00:00 2001 From: pingqiu Date: Sun, 29 Mar 2026 20:58:28 -0700 Subject: [PATCH] fix: encapsulate engine sender/session authority state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All mutable state on Sender and Session is now unexported: - Sender.state, .epoch, .endpoint, .session, .stopped → accessors - Session.id, .phase, .kind, etc. → read-only accessors - Session() replaced by SessionSnapshot() (returns disconnected copy) - SessionID() and HasActiveSession() for common queries - AttachSession returns (sessionID, error) not (*Session, error) - SupersedeSession returns sessionID not *Session Budget configuration via SessionOption: - WithBudget(CatchUpBudget) passed to AttachSession - No direct field mutation on session from external code New test: Encapsulation_SnapshotIsReadOnly proves snapshot mutation does not leak back to sender state. Co-Authored-By: Claude Opus 4.6 (1M context) --- sw-block/engine/replication/ownership_test.go | 198 ++++++------ sw-block/engine/replication/registry.go | 14 +- sw-block/engine/replication/sender.go | 302 +++++++++++------- sw-block/engine/replication/session.go | 120 +++---- 4 files changed, 350 insertions(+), 284 deletions(-) diff --git a/sw-block/engine/replication/ownership_test.go b/sw-block/engine/replication/ownership_test.go index bee6e3d81..b7e067d78 100644 --- a/sw-block/engine/replication/ownership_test.go +++ b/sw-block/engine/replication/ownership_test.go @@ -4,7 +4,6 @@ import "testing" // ============================================================ // Phase 05 Slice 1: Engine ownership/fencing tests -// Mapped to V2 acceptance criteria and boundary cases. // ============================================================ // --- Changed-address invalidation (A10) --- @@ -20,22 +19,18 @@ func TestEngine_ChangedAddress_InvalidatesSession(t *testing.T) { }) s := r.Sender("r1:9333") - sess := s.Session() - s.BeginConnect(sess.ID) + sessID := s.SessionID() + s.BeginConnect(sessID) - // Address changes mid-recovery. r.Reconcile(map[string]Endpoint{ "r1:9333": {DataAddr: "r1:9333", CtrlAddr: "r1:9445", Version: 2}, }, 1) - if sess.Active() { - t.Fatal("session should be invalidated by address change") + if s.HasActiveSession() { + t.Fatal("session should be invalidated") } - if r.Sender("r1:9333") != s { - t.Fatal("sender identity should be preserved") - } - if s.State != StateDisconnected { - t.Fatalf("state=%s, want disconnected", s.State) + if s.State() != StateDisconnected { + t.Fatalf("state=%s", s.State()) } } @@ -50,10 +45,8 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) { }) s := r.Sender("r1:9333") - oldSess := s.Session() - s.BeginConnect(oldSess.ID) + oldID := s.SessionID() - // Address change + new assignment. r.Reconcile(map[string]Endpoint{ "r1:9333": {DataAddr: "r1:9333", Version: 2}, }, 1) @@ -66,9 +59,8 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) { if len(result.SessionsCreated) != 1 { t.Fatalf("should create new session: %v", result) } - newSess := s.Session() - if newSess.ID == oldSess.ID { - t.Fatal("new session should have different ID") + if s.SessionID() == oldID { + t.Fatal("should have different session ID") } } @@ -76,8 +68,7 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) { func TestEngine_StaleSessionID_RejectedAtAllAPIs(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess, _ := s.AttachSession(1, SessionCatchUp) - staleID := sess.ID + staleID, _ := s.AttachSession(1, SessionCatchUp) s.UpdateEpoch(2) s.AttachSession(2, SessionCatchUp) @@ -101,22 +92,18 @@ func TestEngine_StaleSessionID_RejectedAtAllAPIs(t *testing.T) { func TestEngine_StaleCompletion_AfterSupersede(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess1, _ := s.AttachSession(1, SessionCatchUp) - id1 := sess1.ID + id1, _ := s.AttachSession(1, SessionCatchUp) s.UpdateEpoch(2) - sess2, _ := s.AttachSession(2, SessionCatchUp) + s.AttachSession(2, SessionCatchUp) - // Old session completion rejected. if s.CompleteSessionByID(id1) { t.Fatal("stale completion must be rejected") } - // New session still active. - if !sess2.Active() { + if s.HasActiveSession() != true { t.Fatal("new session should be active") } - // Sender not moved to InSync. - if s.State == StateInSync { + if s.State() == StateInSync { t.Fatal("sender should not be InSync from stale completion") } } @@ -130,23 +117,18 @@ func TestEngine_EpochBump_InvalidatesAllSessions(t *testing.T) { "r1:9333": {DataAddr: "r1:9333", Version: 1}, "r2:9333": {DataAddr: "r2:9333", Version: 1}, }, - Epoch: 1, + Epoch: 1, RecoveryTargets: map[string]SessionKind{ "r1:9333": SessionCatchUp, "r2:9333": SessionCatchUp, }, }) - s1 := r.Sender("r1:9333") - s2 := r.Sender("r2:9333") - sess1 := s1.Session() - sess2 := s2.Session() - count := r.InvalidateEpoch(2) if count != 2 { t.Fatalf("should invalidate 2, got %d", count) } - if sess1.Active() || sess2.Active() { + if r.Sender("r1:9333").HasActiveSession() || r.Sender("r2:9333").HasActiveSession() { t.Fatal("both sessions should be invalidated") } } @@ -154,10 +136,8 @@ func TestEngine_EpochBump_InvalidatesAllSessions(t *testing.T) { func TestEngine_EpochBump_StaleAssignment_Rejected(t *testing.T) { r := NewRegistry() r.ApplyAssignment(AssignmentIntent{ - Endpoints: map[string]Endpoint{ - "r1:9333": {DataAddr: "r1:9333", Version: 1}, - }, - Epoch: 2, + Endpoints: map[string]Endpoint{"r1:9333": {DataAddr: "r1:9333", Version: 1}}, + Epoch: 2, }) result := r.ApplyAssignment(AssignmentIntent{ @@ -175,38 +155,38 @@ func TestEngine_EpochBump_StaleAssignment_Rejected(t *testing.T) { func TestEngine_Rebuild_CatchUpAPIs_Rejected(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess, _ := s.AttachSession(1, SessionRebuild) - s.BeginConnect(sess.ID) - s.RecordHandshake(sess.ID, 0, 100) + sessID, _ := s.AttachSession(1, SessionRebuild) + s.BeginConnect(sessID) + s.RecordHandshake(sessID, 0, 100) - if err := s.BeginCatchUp(sess.ID); err == nil { - t.Fatal("rebuild: BeginCatchUp should be rejected") + if err := s.BeginCatchUp(sessID); err == nil { + t.Fatal("rebuild: BeginCatchUp should reject") } - if err := s.RecordCatchUpProgress(sess.ID, 50); err == nil { - t.Fatal("rebuild: RecordCatchUpProgress should be rejected") + if err := s.RecordCatchUpProgress(sessID, 50); err == nil { + t.Fatal("rebuild: RecordCatchUpProgress should reject") } - if s.CompleteSessionByID(sess.ID) { - t.Fatal("rebuild: catch-up completion should be rejected") + if s.CompleteSessionByID(sessID) { + t.Fatal("rebuild: catch-up completion should reject") } } func TestEngine_Rebuild_FullLifecycle(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess, _ := s.AttachSession(1, SessionRebuild) - - s.BeginConnect(sess.ID) - s.RecordHandshake(sess.ID, 0, 100) - s.SelectRebuildSource(sess.ID, 50, true, 100) - s.BeginRebuildTransfer(sess.ID) - s.RecordRebuildTransferProgress(sess.ID, 50) - s.BeginRebuildTailReplay(sess.ID) - s.RecordRebuildTailProgress(sess.ID, 100) - - if err := s.CompleteRebuild(sess.ID); err != nil { - t.Fatalf("rebuild completion: %v", err) + sessID, _ := s.AttachSession(1, SessionRebuild) + + s.BeginConnect(sessID) + s.RecordHandshake(sessID, 0, 100) + s.SelectRebuildSource(sessID, 50, true, 100) + s.BeginRebuildTransfer(sessID) + s.RecordRebuildTransferProgress(sessID, 50) + s.BeginRebuildTailReplay(sessID) + s.RecordRebuildTailProgress(sessID, 100) + + if err := s.CompleteRebuild(sessID); err != nil { + t.Fatalf("rebuild: %v", err) } - if s.State != StateInSync { - t.Fatalf("state=%s, want in_sync", s.State) + if s.State() != StateInSync { + t.Fatalf("state=%s", s.State()) } } @@ -214,37 +194,66 @@ func TestEngine_Rebuild_FullLifecycle(t *testing.T) { func TestEngine_FrozenTarget_RejectsChase(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess, _ := s.AttachSession(1, SessionCatchUp) + sessID, _ := s.AttachSession(1, SessionCatchUp) - s.BeginConnect(sess.ID) - s.RecordHandshake(sess.ID, 0, 50) - s.BeginCatchUp(sess.ID) + s.BeginConnect(sessID) + s.RecordHandshake(sessID, 0, 50) + s.BeginCatchUp(sessID) - if err := s.RecordCatchUpProgress(sess.ID, 51); err == nil { - t.Fatal("progress beyond frozen target should be rejected") + if err := s.RecordCatchUpProgress(sessID, 51); err == nil { + t.Fatal("beyond frozen target should be rejected") } } func TestEngine_BudgetViolation_Escalates(t *testing.T) { s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) - sess, _ := s.AttachSession(1, SessionCatchUp) - sess.Budget = &CatchUpBudget{MaxDurationTicks: 5} + sessID, _ := s.AttachSession(1, SessionCatchUp, WithBudget(CatchUpBudget{MaxDurationTicks: 5})) - s.BeginConnect(sess.ID) - s.RecordHandshake(sess.ID, 0, 100) - s.BeginCatchUp(sess.ID, 0) - s.RecordCatchUpProgress(sess.ID, 10) + s.BeginConnect(sessID) + s.RecordHandshake(sessID, 0, 100) + s.BeginCatchUp(sessID, 0) + s.RecordCatchUpProgress(sessID, 10) - v, _ := s.CheckBudget(sess.ID, 10) + v, _ := s.CheckBudget(sessID, 10) if v != BudgetDurationExceeded { t.Fatalf("budget=%s", v) } - if s.State != StateNeedsRebuild { - t.Fatalf("state=%s", s.State) + if s.State() != StateNeedsRebuild { + t.Fatalf("state=%s", s.State()) + } +} + +// --- Encapsulation: no direct state mutation --- + +func TestEngine_Encapsulation_SnapshotIsReadOnly(t *testing.T) { + s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1) + sessID, _ := s.AttachSession(1, SessionCatchUp) + + snap := s.SessionSnapshot() + if snap == nil || !snap.Active { + t.Fatal("should have active session snapshot") + } + + // Mutating the snapshot does not affect the sender. + snap.Phase = PhaseCompleted + snap.Active = false + + // Sender's session is still active. + if !s.HasActiveSession() { + t.Fatal("sender should still have active session after snapshot mutation") + } + snap2 := s.SessionSnapshot() + if snap2.Phase == PhaseCompleted { + t.Fatal("snapshot mutation should not leak back to sender") + } + + // Can still execute on the real session. + if err := s.BeginConnect(sessID); err != nil { + t.Fatalf("execution should still work: %v", err) } } -// --- E2E: 3 replicas, 3 outcomes --- +// --- E2E --- func TestEngine_E2E_ThreeReplicas_ThreeOutcomes(t *testing.T) { r := NewRegistry() @@ -264,47 +273,46 @@ func TestEngine_E2E_ThreeReplicas_ThreeOutcomes(t *testing.T) { // r1: zero-gap. r1 := r.Sender("r1:9333") - s1 := r1.Session() - r1.BeginConnect(s1.ID) - o1, _ := r1.RecordHandshakeWithOutcome(s1.ID, HandshakeResult{ + id1 := r1.SessionID() + r1.BeginConnect(id1) + o1, _ := r1.RecordHandshakeWithOutcome(id1, HandshakeResult{ ReplicaFlushedLSN: 100, CommittedLSN: 100, RetentionStartLSN: 50, }) if o1 != OutcomeZeroGap { t.Fatalf("r1: %s", o1) } - r1.CompleteSessionByID(s1.ID) + r1.CompleteSessionByID(id1) // r2: catch-up. r2 := r.Sender("r2:9333") - s2 := r2.Session() - r2.BeginConnect(s2.ID) - o2, _ := r2.RecordHandshakeWithOutcome(s2.ID, HandshakeResult{ + id2 := r2.SessionID() + r2.BeginConnect(id2) + o2, _ := r2.RecordHandshakeWithOutcome(id2, HandshakeResult{ ReplicaFlushedLSN: 70, CommittedLSN: 100, RetentionStartLSN: 50, }) if o2 != OutcomeCatchUp { t.Fatalf("r2: %s", o2) } - r2.BeginCatchUp(s2.ID) - r2.RecordCatchUpProgress(s2.ID, 100) - r2.CompleteSessionByID(s2.ID) + r2.BeginCatchUp(id2) + r2.RecordCatchUpProgress(id2, 100) + r2.CompleteSessionByID(id2) // r3: needs rebuild. r3 := r.Sender("r3:9333") - s3 := r3.Session() - r3.BeginConnect(s3.ID) - o3, _ := r3.RecordHandshakeWithOutcome(s3.ID, HandshakeResult{ + id3 := r3.SessionID() + r3.BeginConnect(id3) + o3, _ := r3.RecordHandshakeWithOutcome(id3, HandshakeResult{ ReplicaFlushedLSN: 10, CommittedLSN: 100, RetentionStartLSN: 50, }) if o3 != OutcomeNeedsRebuild { t.Fatalf("r3: %s", o3) } - // Final states. - if r1.State != StateInSync || r2.State != StateInSync { - t.Fatalf("r1=%s r2=%s", r1.State, r2.State) + if r1.State() != StateInSync || r2.State() != StateInSync { + t.Fatalf("r1=%s r2=%s", r1.State(), r2.State()) } - if r3.State != StateNeedsRebuild { - t.Fatalf("r3=%s", r3.State) + if r3.State() != StateNeedsRebuild { + t.Fatalf("r3=%s", r3.State()) } if r.InSyncCount() != 2 { t.Fatalf("in_sync=%d", r.InSyncCount()) diff --git a/sw-block/engine/replication/registry.go b/sw-block/engine/replication/registry.go index 27b0a0bf8..bcaccd1ef 100644 --- a/sw-block/engine/replication/registry.go +++ b/sw-block/engine/replication/registry.go @@ -75,14 +75,14 @@ func (r *Registry) ApplyAssignment(intent AssignmentIntent) AssignmentResult { result.SessionsFailed = append(result.SessionsFailed, replicaID) continue } - if intent.Epoch < sender.Epoch { + if intent.Epoch < sender.Epoch() { result.SessionsFailed = append(result.SessionsFailed, replicaID) continue } _, err := sender.AttachSession(intent.Epoch, kind) if err != nil { - sess := sender.SupersedeSession(kind, "assignment_intent") - if sess != nil { + id := sender.SupersedeSession(kind, "assignment_intent") + if id != 0 { result.SessionsSuperseded = append(result.SessionsSuperseded, replicaID) } else { result.SessionsFailed = append(result.SessionsFailed, replicaID) @@ -110,7 +110,7 @@ func (r *Registry) All() []*Sender { out = append(out, s) } sort.Slice(out, func(i, j int) bool { - return out[i].ReplicaID < out[j].ReplicaID + return out[i].ReplicaID() < out[j].ReplicaID() }) return out } @@ -137,7 +137,7 @@ func (r *Registry) InSyncCount() int { defer r.mu.RUnlock() count := 0 for _, s := range r.senders { - if s.State == StateInSync { + if s.State() == StateInSync { count++ } } @@ -150,8 +150,8 @@ func (r *Registry) InvalidateEpoch(currentEpoch uint64) int { defer r.mu.RUnlock() count := 0 for _, s := range r.senders { - sess := s.Session() - if sess != nil && sess.Epoch < currentEpoch && sess.Active() { + snap := s.SessionSnapshot() + if snap != nil && snap.Epoch < currentEpoch && snap.Active { s.InvalidateSession("epoch_bump", StateDisconnected) count++ } diff --git a/sw-block/engine/replication/sender.go b/sw-block/engine/replication/sender.go index 236c2c6a0..cc89a2f77 100644 --- a/sw-block/engine/replication/sender.go +++ b/sw-block/engine/replication/sender.go @@ -6,15 +6,15 @@ import ( ) // Sender owns the replication channel to one replica. It is the authority -// boundary for all execution operations — every API validates the session -// ID before mutating state. +// boundary for all execution operations. All mutable state is unexported — +// external code reads state through accessors and mutates through execution APIs. type Sender struct { mu sync.Mutex - ReplicaID string - Endpoint Endpoint - Epoch uint64 - State ReplicaState + replicaID string + endpoint Endpoint + epoch uint64 + state ReplicaState session *Session stopped bool @@ -23,26 +23,92 @@ type Sender struct { // NewSender creates a sender for a replica. func NewSender(replicaID string, endpoint Endpoint, epoch uint64) *Sender { return &Sender{ - ReplicaID: replicaID, - Endpoint: endpoint, - Epoch: epoch, - State: StateDisconnected, + replicaID: replicaID, + endpoint: endpoint, + epoch: epoch, + state: StateDisconnected, } } +// Read-only accessors. + +func (s *Sender) ReplicaID() string { s.mu.Lock(); defer s.mu.Unlock(); return s.replicaID } +func (s *Sender) Endpoint() Endpoint { s.mu.Lock(); defer s.mu.Unlock(); return s.endpoint } +func (s *Sender) Epoch() uint64 { s.mu.Lock(); defer s.mu.Unlock(); return s.epoch } +func (s *Sender) State() ReplicaState { s.mu.Lock(); defer s.mu.Unlock(); return s.state } +func (s *Sender) Stopped() bool { s.mu.Lock(); defer s.mu.Unlock(); return s.stopped } + +// SessionSnapshot returns a read-only copy of the current session state. +// Returns nil if no session is active. The returned snapshot is disconnected +// from the live session — mutations to the Sender do not affect it. +func (s *Sender) SessionSnapshot() *SessionSnapshot { + s.mu.Lock() + defer s.mu.Unlock() + if s.session == nil { + return nil + } + return &SessionSnapshot{ + ID: s.session.id, + ReplicaID: s.session.replicaID, + Epoch: s.session.epoch, + Kind: s.session.kind, + Phase: s.session.phase, + InvalidateReason: s.session.invalidateReason, + StartLSN: s.session.startLSN, + TargetLSN: s.session.targetLSN, + FrozenTargetLSN: s.session.frozenTargetLSN, + RecoveredTo: s.session.recoveredTo, + Active: s.session.Active(), + } +} + +// SessionSnapshot is a read-only copy of session state for external inspection. +type SessionSnapshot struct { + ID uint64 + ReplicaID string + Epoch uint64 + Kind SessionKind + Phase SessionPhase + InvalidateReason string + StartLSN uint64 + TargetLSN uint64 + FrozenTargetLSN uint64 + RecoveredTo uint64 + Active bool +} + +// SessionID returns the current session ID, or 0 if no session. +func (s *Sender) SessionID() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + if s.session == nil { + return 0 + } + return s.session.id +} + +// HasActiveSession returns true if a session is currently active. +func (s *Sender) HasActiveSession() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.session != nil && s.session.Active() +} + +// === Lifecycle APIs === + // UpdateEpoch advances the sender's epoch. Invalidates stale sessions. func (s *Sender) UpdateEpoch(epoch uint64) { s.mu.Lock() defer s.mu.Unlock() - if s.stopped || epoch <= s.Epoch { + if s.stopped || epoch <= s.epoch { return } - oldEpoch := s.Epoch - s.Epoch = epoch - if s.session != nil && s.session.Epoch < epoch { + oldEpoch := s.epoch + s.epoch = epoch + if s.session != nil && s.session.epoch < epoch { s.session.invalidate(fmt.Sprintf("epoch_advanced_%d_to_%d", oldEpoch, epoch)) s.session = nil - s.State = StateDisconnected + s.state = StateDisconnected } } @@ -53,59 +119,62 @@ func (s *Sender) UpdateEndpoint(ep Endpoint) { if s.stopped { return } - if s.Endpoint.Changed(ep) && s.session != nil { + if s.endpoint.Changed(ep) && s.session != nil { s.session.invalidate("endpoint_changed") s.session = nil - s.State = StateDisconnected + s.state = StateDisconnected + } + s.endpoint = ep +} + +// SessionOption configures a newly created session. +type SessionOption func(s *Session) + +// WithBudget attaches a catch-up budget to the session. +func WithBudget(budget CatchUpBudget) SessionOption { + return func(s *Session) { + b := budget // copy + s.budget = &b } - s.Endpoint = ep } // AttachSession creates a new recovery session. Epoch must match sender epoch. -func (s *Sender) AttachSession(epoch uint64, kind SessionKind) (*Session, error) { +func (s *Sender) AttachSession(epoch uint64, kind SessionKind, opts ...SessionOption) (uint64, error) { s.mu.Lock() defer s.mu.Unlock() if s.stopped { - return nil, fmt.Errorf("sender stopped") + return 0, fmt.Errorf("sender stopped") } - if epoch != s.Epoch { - return nil, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.Epoch, epoch) + if epoch != s.epoch { + return 0, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.epoch, epoch) } if s.session != nil && s.session.Active() { - return nil, fmt.Errorf("session already active (id=%d)", s.session.ID) + return 0, fmt.Errorf("session already active (id=%d)", s.session.id) + } + sess := newSession(s.replicaID, epoch, kind) + for _, opt := range opts { + opt(sess) } - sess := newSession(s.ReplicaID, epoch, kind) s.session = sess - return sess, nil + return sess.id, nil } // SupersedeSession invalidates current session and attaches new at sender epoch. -func (s *Sender) SupersedeSession(kind SessionKind, reason string) *Session { +func (s *Sender) SupersedeSession(kind SessionKind, reason string, opts ...SessionOption) uint64 { s.mu.Lock() defer s.mu.Unlock() if s.stopped { - return nil + return 0 } if s.session != nil { s.session.invalidate(reason) } - sess := newSession(s.ReplicaID, s.Epoch, kind) + sess := newSession(s.replicaID, s.epoch, kind) + for _, opt := range opts { + opt(sess) + } s.session = sess - return sess -} - -// Session returns the current session, or nil. -func (s *Sender) Session() *Session { - s.mu.Lock() - defer s.mu.Unlock() - return s.session -} - -// Stopped returns true if the sender has been stopped. -func (s *Sender) Stopped() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.stopped + return sess.id } // Stop shuts down the sender. @@ -130,26 +199,24 @@ func (s *Sender) InvalidateSession(reason string, targetState ReplicaState) { s.session.invalidate(reason) s.session = nil } - s.State = targetState + s.state = targetState } // === Catch-up execution APIs === -// BeginConnect transitions init → connecting. func (s *Sender) BeginConnect(sessionID uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if !s.session.Advance(PhaseConnecting) { - return fmt.Errorf("cannot begin connect: phase=%s", s.session.Phase) + if !s.session.advance(PhaseConnecting) { + return fmt.Errorf("cannot begin connect: phase=%s", s.session.phase) } - s.State = StateConnecting + s.state = StateConnecting return nil } -// RecordHandshake records handshake result and sets catch-up range. func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) error { s.mu.Lock() defer s.mu.Unlock() @@ -159,14 +226,13 @@ func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) e if targetLSN < startLSN { return fmt.Errorf("invalid range: target=%d < start=%d", targetLSN, startLSN) } - if !s.session.Advance(PhaseHandshake) { - return fmt.Errorf("cannot record handshake: phase=%s", s.session.Phase) + if !s.session.advance(PhaseHandshake) { + return fmt.Errorf("cannot record handshake: phase=%s", s.session.phase) } - s.session.SetRange(startLSN, targetLSN) + s.session.setRange(startLSN, targetLSN) return nil } -// RecordHandshakeWithOutcome records handshake and classifies the recovery outcome. func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeResult) (RecoveryOutcome, error) { outcome := ClassifyRecoveryOutcome(result) s.mu.Lock() @@ -174,106 +240,100 @@ func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeRe if err := s.checkAuthority(sessionID); err != nil { return outcome, err } - if s.session.Phase != PhaseConnecting { - return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.Phase) + if s.session.phase != PhaseConnecting { + return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.phase) } if outcome == OutcomeNeedsRebuild { s.session.invalidate("gap_exceeds_retention") s.session = nil - s.State = StateNeedsRebuild + s.state = StateNeedsRebuild return outcome, nil } - if !s.session.Advance(PhaseHandshake) { - return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.Phase) + if !s.session.advance(PhaseHandshake) { + return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.phase) } switch outcome { case OutcomeZeroGap: - s.session.SetRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN) + s.session.setRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN) case OutcomeCatchUp: if result.ReplicaFlushedLSN > result.CommittedLSN { - s.session.TruncateRequired = true - s.session.TruncateToLSN = result.CommittedLSN - s.session.SetRange(result.CommittedLSN, result.CommittedLSN) + s.session.truncateRequired = true + s.session.truncateToLSN = result.CommittedLSN + s.session.setRange(result.CommittedLSN, result.CommittedLSN) } else { - s.session.SetRange(result.ReplicaFlushedLSN, result.CommittedLSN) + s.session.setRange(result.ReplicaFlushedLSN, result.CommittedLSN) } } return outcome, nil } -// BeginCatchUp transitions to catch-up phase. Rejects rebuild sessions. -// Freezes the target unconditionally. func (s *Sender) BeginCatchUp(sessionID uint64, startTick ...uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Kind == SessionRebuild { + if s.session.kind == SessionRebuild { return fmt.Errorf("rebuild sessions must use rebuild APIs") } - if !s.session.Advance(PhaseCatchUp) { - return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.Phase) + if !s.session.advance(PhaseCatchUp) { + return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.phase) } - s.State = StateCatchingUp - s.session.FrozenTargetLSN = s.session.TargetLSN + s.state = StateCatchingUp + s.session.frozenTargetLSN = s.session.targetLSN if len(startTick) > 0 { - s.session.Tracker.StartTick = startTick[0] - s.session.Tracker.LastProgressTick = startTick[0] + s.session.tracker.StartTick = startTick[0] + s.session.tracker.LastProgressTick = startTick[0] } return nil } -// RecordCatchUpProgress records catch-up progress. Rejects rebuild sessions. -// Entry counting uses LSN delta. Tick is required when ProgressDeadlineTicks > 0. func (s *Sender) RecordCatchUpProgress(sessionID uint64, recoveredTo uint64, tick ...uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Kind == SessionRebuild { + if s.session.kind == SessionRebuild { return fmt.Errorf("rebuild sessions must use rebuild APIs") } - if s.session.Phase != PhaseCatchUp { - return fmt.Errorf("progress requires catchup phase, got %s", s.session.Phase) + if s.session.phase != PhaseCatchUp { + return fmt.Errorf("progress requires catchup phase, got %s", s.session.phase) } - if recoveredTo <= s.session.RecoveredTo { - return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.RecoveredTo) + if recoveredTo <= s.session.recoveredTo { + return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.recoveredTo) } - if s.session.FrozenTargetLSN > 0 && recoveredTo > s.session.FrozenTargetLSN { - return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.FrozenTargetLSN) + if s.session.frozenTargetLSN > 0 && recoveredTo > s.session.frozenTargetLSN { + return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.frozenTargetLSN) } - if s.session.Budget != nil && s.session.Budget.ProgressDeadlineTicks > 0 && len(tick) == 0 { + if s.session.budget != nil && s.session.budget.ProgressDeadlineTicks > 0 && len(tick) == 0 { return fmt.Errorf("tick required when ProgressDeadlineTicks > 0") } - delta := recoveredTo - s.session.RecoveredTo - s.session.Tracker.EntriesReplayed += delta - s.session.UpdateProgress(recoveredTo) + delta := recoveredTo - s.session.recoveredTo + s.session.tracker.EntriesReplayed += delta + s.session.updateProgress(recoveredTo) if len(tick) > 0 { - s.session.Tracker.LastProgressTick = tick[0] + s.session.tracker.LastProgressTick = tick[0] } return nil } -// RecordTruncation confirms divergent tail cleanup. func (s *Sender) RecordTruncation(sessionID uint64, truncatedToLSN uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if !s.session.TruncateRequired { + if !s.session.truncateRequired { return fmt.Errorf("truncation not required") } - if truncatedToLSN != s.session.TruncateToLSN { - return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.TruncateToLSN, truncatedToLSN) + if truncatedToLSN != s.session.truncateToLSN { + return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.truncateToLSN, truncatedToLSN) } - s.session.TruncateRecorded = true + s.session.truncateRecorded = true return nil } -// CompleteSessionByID completes catch-up sessions. Rejects rebuild sessions. func (s *Sender) CompleteSessionByID(sessionID uint64) bool { s.mu.Lock() defer s.mu.Unlock() @@ -281,19 +341,19 @@ func (s *Sender) CompleteSessionByID(sessionID uint64) bool { return false } sess := s.session - if sess.Kind == SessionRebuild { + if sess.kind == SessionRebuild { return false } - if sess.TruncateRequired && !sess.TruncateRecorded { + if sess.truncateRequired && !sess.truncateRecorded { return false } - switch sess.Phase { + switch sess.phase { case PhaseCatchUp: if !sess.Converged() { return false } case PhaseHandshake: - if sess.TargetLSN != sess.StartLSN { + if sess.targetLSN != sess.startLSN { return false } default: @@ -301,48 +361,46 @@ func (s *Sender) CompleteSessionByID(sessionID uint64) bool { } sess.complete() s.session = nil - s.State = StateInSync + s.state = StateInSync return true } -// CheckBudget evaluates catch-up budget. Auto-escalates on violation. func (s *Sender) CheckBudget(sessionID uint64, currentTick uint64) (BudgetViolation, error) { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return BudgetOK, err } - if s.session.Budget == nil { + if s.session.budget == nil { return BudgetOK, nil } - v := s.session.Budget.Check(s.session.Tracker, currentTick) + v := s.session.budget.Check(s.session.tracker, currentTick) if v != BudgetOK { s.session.invalidate(fmt.Sprintf("budget_%s", v)) s.session = nil - s.State = StateNeedsRebuild + s.state = StateNeedsRebuild } return v, nil } // === Rebuild execution APIs === -// SelectRebuildSource chooses rebuild source. Requires PhaseHandshake. func (s *Sender) SelectRebuildSource(sessionID uint64, snapshotLSN uint64, snapshotValid bool, committedLSN uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Kind != SessionRebuild { + if s.session.kind != SessionRebuild { return fmt.Errorf("not a rebuild session") } - if s.session.Phase != PhaseHandshake { - return fmt.Errorf("requires PhaseHandshake, got %s", s.session.Phase) + if s.session.phase != PhaseHandshake { + return fmt.Errorf("requires PhaseHandshake, got %s", s.session.phase) } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("rebuild state not initialized") } - return s.session.Rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN) + return s.session.rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN) } func (s *Sender) BeginRebuildTransfer(sessionID uint64) error { @@ -351,10 +409,10 @@ func (s *Sender) BeginRebuildTransfer(sessionID uint64) error { if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("no rebuild state") } - return s.session.Rebuild.BeginTransfer() + return s.session.rebuild.BeginTransfer() } func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo uint64) error { @@ -363,10 +421,10 @@ func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo u if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("no rebuild state") } - return s.session.Rebuild.RecordTransferProgress(transferredTo) + return s.session.rebuild.RecordTransferProgress(transferredTo) } func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error { @@ -375,10 +433,10 @@ func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error { if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("no rebuild state") } - return s.session.Rebuild.BeginTailReplay() + return s.session.rebuild.BeginTailReplay() } func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64) error { @@ -387,32 +445,30 @@ func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64) if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("no rebuild state") } - return s.session.Rebuild.RecordTailReplayProgress(replayedTo) + return s.session.rebuild.RecordTailReplayProgress(replayedTo) } -// CompleteRebuild completes a rebuild session. Requires ReadyToComplete. func (s *Sender) CompleteRebuild(sessionID uint64) error { s.mu.Lock() defer s.mu.Unlock() if err := s.checkAuthority(sessionID); err != nil { return err } - if s.session.Rebuild == nil { + if s.session.rebuild == nil { return fmt.Errorf("no rebuild state") } - if err := s.session.Rebuild.Complete(); err != nil { + if err := s.session.rebuild.Complete(); err != nil { return err } s.session.complete() s.session = nil - s.State = StateInSync + s.state = StateInSync return nil } -// checkAuthority validates session ownership. func (s *Sender) checkAuthority(sessionID uint64) error { if s.stopped { return fmt.Errorf("sender stopped") @@ -420,11 +476,11 @@ func (s *Sender) checkAuthority(sessionID uint64) error { if s.session == nil { return fmt.Errorf("no active session") } - if s.session.ID != sessionID { - return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.ID, sessionID) + if s.session.id != sessionID { + return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.id, sessionID) } if !s.session.Active() { - return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.Phase) + return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.phase) } return nil } diff --git a/sw-block/engine/replication/session.go b/sw-block/engine/replication/session.go index 0d5d2989d..4613411df 100644 --- a/sw-block/engine/replication/session.go +++ b/sw-block/engine/replication/session.go @@ -6,98 +6,100 @@ import "sync/atomic" var sessionIDCounter atomic.Uint64 // Session represents one recovery attempt for a specific replica at a -// specific epoch. It is owned by a Sender and gated by session ID at -// every execution step. -// -// Lifecycle: -// - Created via Sender.AttachSession or Sender.SupersedeSession -// - Advanced through phases: init → connecting → handshake → catchup → completed -// - Invalidated by: epoch bump, endpoint change, sender stop, timeout -// - Stale sessions (wrong ID) are rejected at every execution API +// specific epoch. All mutable state is unexported — external code interacts +// through Sender execution APIs only. type Session struct { - ID uint64 - ReplicaID string - Epoch uint64 - Kind SessionKind - Phase SessionPhase - InvalidateReason string - - // Progress tracking. - StartLSN uint64 // gap start (exclusive) - TargetLSN uint64 // gap end (inclusive) - FrozenTargetLSN uint64 // frozen at BeginCatchUp — catch-up will not chase beyond - RecoveredTo uint64 // highest LSN recovered so far - - // Truncation. - TruncateRequired bool - TruncateToLSN uint64 - TruncateRecorded bool - - // Budget (nil = no enforcement). - Budget *CatchUpBudget - Tracker BudgetCheck - - // Rebuild state (non-nil when Kind == SessionRebuild). - Rebuild *RebuildState + id uint64 + replicaID string + epoch uint64 + kind SessionKind + phase SessionPhase + invalidateReason string + + startLSN uint64 + targetLSN uint64 + frozenTargetLSN uint64 + recoveredTo uint64 + + truncateRequired bool + truncateToLSN uint64 + truncateRecorded bool + + budget *CatchUpBudget + tracker BudgetCheck + + rebuild *RebuildState } func newSession(replicaID string, epoch uint64, kind SessionKind) *Session { s := &Session{ - ID: sessionIDCounter.Add(1), - ReplicaID: replicaID, - Epoch: epoch, - Kind: kind, - Phase: PhaseInit, + id: sessionIDCounter.Add(1), + replicaID: replicaID, + epoch: epoch, + kind: kind, + phase: PhaseInit, } if kind == SessionRebuild { - s.Rebuild = NewRebuildState() + s.rebuild = NewRebuildState() } return s } +// Read-only accessors. + +func (s *Session) ID() uint64 { return s.id } +func (s *Session) ReplicaID() string { return s.replicaID } +func (s *Session) Epoch() uint64 { return s.epoch } +func (s *Session) Kind() SessionKind { return s.kind } +func (s *Session) Phase() SessionPhase { return s.phase } +func (s *Session) InvalidateReason() string { return s.invalidateReason } +func (s *Session) StartLSN() uint64 { return s.startLSN } +func (s *Session) TargetLSN() uint64 { return s.targetLSN } +func (s *Session) FrozenTargetLSN() uint64 { return s.frozenTargetLSN } +func (s *Session) RecoveredTo() uint64 { return s.recoveredTo } + // Active returns true if the session is not completed or invalidated. func (s *Session) Active() bool { - return s.Phase != PhaseCompleted && s.Phase != PhaseInvalidated + return s.phase != PhaseCompleted && s.phase != PhaseInvalidated } -// Advance moves to the next phase. Returns false if the transition is invalid. -func (s *Session) Advance(phase SessionPhase) bool { +// Converged returns true if recovery reached the target. +func (s *Session) Converged() bool { + return s.targetLSN > 0 && s.recoveredTo >= s.targetLSN +} + +// Internal mutation methods — called by Sender under its lock. + +func (s *Session) advance(phase SessionPhase) bool { if !s.Active() { return false } - if !validTransitions[s.Phase][phase] { + if !validTransitions[s.phase][phase] { return false } - s.Phase = phase + s.phase = phase return true } -// SetRange sets the recovery LSN range. -func (s *Session) SetRange(start, target uint64) { - s.StartLSN = start - s.TargetLSN = target +func (s *Session) setRange(start, target uint64) { + s.startLSN = start + s.targetLSN = target } -// UpdateProgress records catch-up progress (monotonic). -func (s *Session) UpdateProgress(recoveredTo uint64) { - if recoveredTo > s.RecoveredTo { - s.RecoveredTo = recoveredTo +func (s *Session) updateProgress(recoveredTo uint64) { + if recoveredTo > s.recoveredTo { + s.recoveredTo = recoveredTo } } -// Converged returns true if recovery reached the target. -func (s *Session) Converged() bool { - return s.TargetLSN > 0 && s.RecoveredTo >= s.TargetLSN -} - func (s *Session) complete() { - s.Phase = PhaseCompleted + s.phase = PhaseCompleted } func (s *Session) invalidate(reason string) { if !s.Active() { return } - s.Phase = PhaseInvalidated - s.InvalidateReason = reason + s.phase = PhaseInvalidated + s.invalidateReason = reason }