Browse Source

fix: encapsulate engine sender/session authority state

All mutable state on Sender and Session is now unexported:
- Sender.state, .epoch, .endpoint, .session, .stopped → accessors
- Session.id, .phase, .kind, etc. → read-only accessors
- Session() replaced by SessionSnapshot() (returns disconnected copy)
- SessionID() and HasActiveSession() for common queries
- AttachSession returns (sessionID, error) not (*Session, error)
- SupersedeSession returns sessionID not *Session

Budget configuration via SessionOption:
- WithBudget(CatchUpBudget) passed to AttachSession
- No direct field mutation on session from external code

New test: Encapsulation_SnapshotIsReadOnly proves snapshot
mutation does not leak back to sender state.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
feature/sw-block
pingqiu 2 days ago
parent
commit
bb24b4b039
  1. 198
      sw-block/engine/replication/ownership_test.go
  2. 14
      sw-block/engine/replication/registry.go
  3. 302
      sw-block/engine/replication/sender.go
  4. 120
      sw-block/engine/replication/session.go

198
sw-block/engine/replication/ownership_test.go

@@ -4,7 +4,6 @@ import "testing"
// ============================================================
// Phase 05 Slice 1: Engine ownership/fencing tests
// Mapped to V2 acceptance criteria and boundary cases.
// ============================================================
// --- Changed-address invalidation (A10) ---
@@ -20,22 +19,18 @@ func TestEngine_ChangedAddress_InvalidatesSession(t *testing.T) {
})
s := r.Sender("r1:9333")
sess := s.Session()
s.BeginConnect(sess.ID)
sessID := s.SessionID()
s.BeginConnect(sessID)
// Address changes mid-recovery.
r.Reconcile(map[string]Endpoint{
"r1:9333": {DataAddr: "r1:9333", CtrlAddr: "r1:9445", Version: 2},
}, 1)
if sess.Active() {
t.Fatal("session should be invalidated by address change")
if s.HasActiveSession() {
t.Fatal("session should be invalidated")
}
if r.Sender("r1:9333") != s {
t.Fatal("sender identity should be preserved")
}
if s.State != StateDisconnected {
t.Fatalf("state=%s, want disconnected", s.State)
if s.State() != StateDisconnected {
t.Fatalf("state=%s", s.State())
}
}
@@ -50,10 +45,8 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) {
})
s := r.Sender("r1:9333")
oldSess := s.Session()
s.BeginConnect(oldSess.ID)
oldID := s.SessionID()
// Address change + new assignment.
r.Reconcile(map[string]Endpoint{
"r1:9333": {DataAddr: "r1:9333", Version: 2},
}, 1)
@@ -66,9 +59,8 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) {
if len(result.SessionsCreated) != 1 {
t.Fatalf("should create new session: %v", result)
}
newSess := s.Session()
if newSess.ID == oldSess.ID {
t.Fatal("new session should have different ID")
if s.SessionID() == oldID {
t.Fatal("should have different session ID")
}
}
@@ -76,8 +68,7 @@ func TestEngine_ChangedAddress_NewSessionAfterUpdate(t *testing.T) {
func TestEngine_StaleSessionID_RejectedAtAllAPIs(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess, _ := s.AttachSession(1, SessionCatchUp)
staleID := sess.ID
staleID, _ := s.AttachSession(1, SessionCatchUp)
s.UpdateEpoch(2)
s.AttachSession(2, SessionCatchUp)
@@ -101,22 +92,18 @@ func TestEngine_StaleSessionID_RejectedAtAllAPIs(t *testing.T) {
func TestEngine_StaleCompletion_AfterSupersede(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess1, _ := s.AttachSession(1, SessionCatchUp)
id1 := sess1.ID
id1, _ := s.AttachSession(1, SessionCatchUp)
s.UpdateEpoch(2)
sess2, _ := s.AttachSession(2, SessionCatchUp)
s.AttachSession(2, SessionCatchUp)
// Old session completion rejected.
if s.CompleteSessionByID(id1) {
t.Fatal("stale completion must be rejected")
}
// New session still active.
if !sess2.Active() {
if s.HasActiveSession() != true {
t.Fatal("new session should be active")
}
// Sender not moved to InSync.
if s.State == StateInSync {
if s.State() == StateInSync {
t.Fatal("sender should not be InSync from stale completion")
}
}
@@ -130,23 +117,18 @@ func TestEngine_EpochBump_InvalidatesAllSessions(t *testing.T) {
"r1:9333": {DataAddr: "r1:9333", Version: 1},
"r2:9333": {DataAddr: "r2:9333", Version: 1},
},
Epoch: 1,
Epoch: 1,
RecoveryTargets: map[string]SessionKind{
"r1:9333": SessionCatchUp,
"r2:9333": SessionCatchUp,
},
})
s1 := r.Sender("r1:9333")
s2 := r.Sender("r2:9333")
sess1 := s1.Session()
sess2 := s2.Session()
count := r.InvalidateEpoch(2)
if count != 2 {
t.Fatalf("should invalidate 2, got %d", count)
}
if sess1.Active() || sess2.Active() {
if r.Sender("r1:9333").HasActiveSession() || r.Sender("r2:9333").HasActiveSession() {
t.Fatal("both sessions should be invalidated")
}
}
@@ -154,10 +136,8 @@ func TestEngine_EpochBump_InvalidatesAllSessions(t *testing.T) {
func TestEngine_EpochBump_StaleAssignment_Rejected(t *testing.T) {
r := NewRegistry()
r.ApplyAssignment(AssignmentIntent{
Endpoints: map[string]Endpoint{
"r1:9333": {DataAddr: "r1:9333", Version: 1},
},
Epoch: 2,
Endpoints: map[string]Endpoint{"r1:9333": {DataAddr: "r1:9333", Version: 1}},
Epoch: 2,
})
result := r.ApplyAssignment(AssignmentIntent{
@@ -175,38 +155,38 @@ func TestEngine_EpochBump_StaleAssignment_Rejected(t *testing.T) {
func TestEngine_Rebuild_CatchUpAPIs_Rejected(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess, _ := s.AttachSession(1, SessionRebuild)
s.BeginConnect(sess.ID)
s.RecordHandshake(sess.ID, 0, 100)
sessID, _ := s.AttachSession(1, SessionRebuild)
s.BeginConnect(sessID)
s.RecordHandshake(sessID, 0, 100)
if err := s.BeginCatchUp(sess.ID); err == nil {
t.Fatal("rebuild: BeginCatchUp should be rejected")
if err := s.BeginCatchUp(sessID); err == nil {
t.Fatal("rebuild: BeginCatchUp should reject")
}
if err := s.RecordCatchUpProgress(sess.ID, 50); err == nil {
t.Fatal("rebuild: RecordCatchUpProgress should be rejected")
if err := s.RecordCatchUpProgress(sessID, 50); err == nil {
t.Fatal("rebuild: RecordCatchUpProgress should reject")
}
if s.CompleteSessionByID(sess.ID) {
t.Fatal("rebuild: catch-up completion should be rejected")
if s.CompleteSessionByID(sessID) {
t.Fatal("rebuild: catch-up completion should reject")
}
}
func TestEngine_Rebuild_FullLifecycle(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess, _ := s.AttachSession(1, SessionRebuild)
s.BeginConnect(sess.ID)
s.RecordHandshake(sess.ID, 0, 100)
s.SelectRebuildSource(sess.ID, 50, true, 100)
s.BeginRebuildTransfer(sess.ID)
s.RecordRebuildTransferProgress(sess.ID, 50)
s.BeginRebuildTailReplay(sess.ID)
s.RecordRebuildTailProgress(sess.ID, 100)
if err := s.CompleteRebuild(sess.ID); err != nil {
t.Fatalf("rebuild completion: %v", err)
sessID, _ := s.AttachSession(1, SessionRebuild)
s.BeginConnect(sessID)
s.RecordHandshake(sessID, 0, 100)
s.SelectRebuildSource(sessID, 50, true, 100)
s.BeginRebuildTransfer(sessID)
s.RecordRebuildTransferProgress(sessID, 50)
s.BeginRebuildTailReplay(sessID)
s.RecordRebuildTailProgress(sessID, 100)
if err := s.CompleteRebuild(sessID); err != nil {
t.Fatalf("rebuild: %v", err)
}
if s.State != StateInSync {
t.Fatalf("state=%s, want in_sync", s.State)
if s.State() != StateInSync {
t.Fatalf("state=%s", s.State())
}
}
@@ -214,37 +194,66 @@ func TestEngine_Rebuild_FullLifecycle(t *testing.T) {
func TestEngine_FrozenTarget_RejectsChase(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess, _ := s.AttachSession(1, SessionCatchUp)
sessID, _ := s.AttachSession(1, SessionCatchUp)
s.BeginConnect(sess.ID)
s.RecordHandshake(sess.ID, 0, 50)
s.BeginCatchUp(sess.ID)
s.BeginConnect(sessID)
s.RecordHandshake(sessID, 0, 50)
s.BeginCatchUp(sessID)
if err := s.RecordCatchUpProgress(sess.ID, 51); err == nil {
t.Fatal("progress beyond frozen target should be rejected")
if err := s.RecordCatchUpProgress(sessID, 51); err == nil {
t.Fatal("beyond frozen target should be rejected")
}
}
func TestEngine_BudgetViolation_Escalates(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sess, _ := s.AttachSession(1, SessionCatchUp)
sess.Budget = &CatchUpBudget{MaxDurationTicks: 5}
sessID, _ := s.AttachSession(1, SessionCatchUp, WithBudget(CatchUpBudget{MaxDurationTicks: 5}))
s.BeginConnect(sess.ID)
s.RecordHandshake(sess.ID, 0, 100)
s.BeginCatchUp(sess.ID, 0)
s.RecordCatchUpProgress(sess.ID, 10)
s.BeginConnect(sessID)
s.RecordHandshake(sessID, 0, 100)
s.BeginCatchUp(sessID, 0)
s.RecordCatchUpProgress(sessID, 10)
v, _ := s.CheckBudget(sess.ID, 10)
v, _ := s.CheckBudget(sessID, 10)
if v != BudgetDurationExceeded {
t.Fatalf("budget=%s", v)
}
if s.State != StateNeedsRebuild {
t.Fatalf("state=%s", s.State)
if s.State() != StateNeedsRebuild {
t.Fatalf("state=%s", s.State())
}
}
// --- Encapsulation: no direct state mutation ---
func TestEngine_Encapsulation_SnapshotIsReadOnly(t *testing.T) {
s := NewSender("r1:9333", Endpoint{DataAddr: "r1:9333", Version: 1}, 1)
sessID, _ := s.AttachSession(1, SessionCatchUp)
snap := s.SessionSnapshot()
if snap == nil || !snap.Active {
t.Fatal("should have active session snapshot")
}
// Mutating the snapshot does not affect the sender.
snap.Phase = PhaseCompleted
snap.Active = false
// Sender's session is still active.
if !s.HasActiveSession() {
t.Fatal("sender should still have active session after snapshot mutation")
}
snap2 := s.SessionSnapshot()
if snap2.Phase == PhaseCompleted {
t.Fatal("snapshot mutation should not leak back to sender")
}
// Can still execute on the real session.
if err := s.BeginConnect(sessID); err != nil {
t.Fatalf("execution should still work: %v", err)
}
}
// --- E2E: 3 replicas, 3 outcomes ---
// --- E2E ---
func TestEngine_E2E_ThreeReplicas_ThreeOutcomes(t *testing.T) {
r := NewRegistry()
@@ -264,47 +273,46 @@ func TestEngine_E2E_ThreeReplicas_ThreeOutcomes(t *testing.T) {
// r1: zero-gap.
r1 := r.Sender("r1:9333")
s1 := r1.Session()
r1.BeginConnect(s1.ID)
o1, _ := r1.RecordHandshakeWithOutcome(s1.ID, HandshakeResult{
id1 := r1.SessionID()
r1.BeginConnect(id1)
o1, _ := r1.RecordHandshakeWithOutcome(id1, HandshakeResult{
ReplicaFlushedLSN: 100, CommittedLSN: 100, RetentionStartLSN: 50,
})
if o1 != OutcomeZeroGap {
t.Fatalf("r1: %s", o1)
}
r1.CompleteSessionByID(s1.ID)
r1.CompleteSessionByID(id1)
// r2: catch-up.
r2 := r.Sender("r2:9333")
s2 := r2.Session()
r2.BeginConnect(s2.ID)
o2, _ := r2.RecordHandshakeWithOutcome(s2.ID, HandshakeResult{
id2 := r2.SessionID()
r2.BeginConnect(id2)
o2, _ := r2.RecordHandshakeWithOutcome(id2, HandshakeResult{
ReplicaFlushedLSN: 70, CommittedLSN: 100, RetentionStartLSN: 50,
})
if o2 != OutcomeCatchUp {
t.Fatalf("r2: %s", o2)
}
r2.BeginCatchUp(s2.ID)
r2.RecordCatchUpProgress(s2.ID, 100)
r2.CompleteSessionByID(s2.ID)
r2.BeginCatchUp(id2)
r2.RecordCatchUpProgress(id2, 100)
r2.CompleteSessionByID(id2)
// r3: needs rebuild.
r3 := r.Sender("r3:9333")
s3 := r3.Session()
r3.BeginConnect(s3.ID)
o3, _ := r3.RecordHandshakeWithOutcome(s3.ID, HandshakeResult{
id3 := r3.SessionID()
r3.BeginConnect(id3)
o3, _ := r3.RecordHandshakeWithOutcome(id3, HandshakeResult{
ReplicaFlushedLSN: 10, CommittedLSN: 100, RetentionStartLSN: 50,
})
if o3 != OutcomeNeedsRebuild {
t.Fatalf("r3: %s", o3)
}
// Final states.
if r1.State != StateInSync || r2.State != StateInSync {
t.Fatalf("r1=%s r2=%s", r1.State, r2.State)
if r1.State() != StateInSync || r2.State() != StateInSync {
t.Fatalf("r1=%s r2=%s", r1.State(), r2.State())
}
if r3.State != StateNeedsRebuild {
t.Fatalf("r3=%s", r3.State)
if r3.State() != StateNeedsRebuild {
t.Fatalf("r3=%s", r3.State())
}
if r.InSyncCount() != 2 {
t.Fatalf("in_sync=%d", r.InSyncCount())

14
sw-block/engine/replication/registry.go

@@ -75,14 +75,14 @@ func (r *Registry) ApplyAssignment(intent AssignmentIntent) AssignmentResult {
result.SessionsFailed = append(result.SessionsFailed, replicaID)
continue
}
if intent.Epoch < sender.Epoch {
if intent.Epoch < sender.Epoch() {
result.SessionsFailed = append(result.SessionsFailed, replicaID)
continue
}
_, err := sender.AttachSession(intent.Epoch, kind)
if err != nil {
sess := sender.SupersedeSession(kind, "assignment_intent")
if sess != nil {
id := sender.SupersedeSession(kind, "assignment_intent")
if id != 0 {
result.SessionsSuperseded = append(result.SessionsSuperseded, replicaID)
} else {
result.SessionsFailed = append(result.SessionsFailed, replicaID)
@@ -110,7 +110,7 @@ func (r *Registry) All() []*Sender {
out = append(out, s)
}
sort.Slice(out, func(i, j int) bool {
return out[i].ReplicaID < out[j].ReplicaID
return out[i].ReplicaID() < out[j].ReplicaID()
})
return out
}
@@ -137,7 +137,7 @@ func (r *Registry) InSyncCount() int {
defer r.mu.RUnlock()
count := 0
for _, s := range r.senders {
if s.State == StateInSync {
if s.State() == StateInSync {
count++
}
}
@@ -150,8 +150,8 @@ func (r *Registry) InvalidateEpoch(currentEpoch uint64) int {
defer r.mu.RUnlock()
count := 0
for _, s := range r.senders {
sess := s.Session()
if sess != nil && sess.Epoch < currentEpoch && sess.Active() {
snap := s.SessionSnapshot()
if snap != nil && snap.Epoch < currentEpoch && snap.Active {
s.InvalidateSession("epoch_bump", StateDisconnected)
count++
}

302
sw-block/engine/replication/sender.go

@@ -6,15 +6,15 @@ import (
)
// Sender owns the replication channel to one replica. It is the authority
// boundary for all execution operations — every API validates the session
// ID before mutating state.
// boundary for all execution operations. All mutable state is unexported —
// external code reads state through accessors and mutates through execution APIs.
type Sender struct {
mu sync.Mutex
ReplicaID string
Endpoint Endpoint
Epoch uint64
State ReplicaState
replicaID string
endpoint Endpoint
epoch uint64
state ReplicaState
session *Session
stopped bool
@@ -23,26 +23,92 @@ type Sender struct {
// NewSender creates a sender for a replica.
func NewSender(replicaID string, endpoint Endpoint, epoch uint64) *Sender {
return &Sender{
ReplicaID: replicaID,
Endpoint: endpoint,
Epoch: epoch,
State: StateDisconnected,
replicaID: replicaID,
endpoint: endpoint,
epoch: epoch,
state: StateDisconnected,
}
}
// Read-only accessors.
func (s *Sender) ReplicaID() string { s.mu.Lock(); defer s.mu.Unlock(); return s.replicaID }
func (s *Sender) Endpoint() Endpoint { s.mu.Lock(); defer s.mu.Unlock(); return s.endpoint }
func (s *Sender) Epoch() uint64 { s.mu.Lock(); defer s.mu.Unlock(); return s.epoch }
func (s *Sender) State() ReplicaState { s.mu.Lock(); defer s.mu.Unlock(); return s.state }
func (s *Sender) Stopped() bool { s.mu.Lock(); defer s.mu.Unlock(); return s.stopped }
// SessionSnapshot returns a read-only copy of the current session state.
// Returns nil if no session is active. The returned snapshot is disconnected
// from the live session — mutations to the Sender do not affect it.
func (s *Sender) SessionSnapshot() *SessionSnapshot {
s.mu.Lock()
defer s.mu.Unlock()
if s.session == nil {
return nil
}
return &SessionSnapshot{
ID: s.session.id,
ReplicaID: s.session.replicaID,
Epoch: s.session.epoch,
Kind: s.session.kind,
Phase: s.session.phase,
InvalidateReason: s.session.invalidateReason,
StartLSN: s.session.startLSN,
TargetLSN: s.session.targetLSN,
FrozenTargetLSN: s.session.frozenTargetLSN,
RecoveredTo: s.session.recoveredTo,
Active: s.session.Active(),
}
}
// SessionSnapshot is a read-only copy of session state for external inspection.
type SessionSnapshot struct {
ID uint64
ReplicaID string
Epoch uint64
Kind SessionKind
Phase SessionPhase
InvalidateReason string
StartLSN uint64
TargetLSN uint64
FrozenTargetLSN uint64
RecoveredTo uint64
Active bool
}
// SessionID returns the current session ID, or 0 if no session.
func (s *Sender) SessionID() uint64 {
s.mu.Lock()
defer s.mu.Unlock()
if s.session == nil {
return 0
}
return s.session.id
}
// HasActiveSession returns true if a session is currently active.
func (s *Sender) HasActiveSession() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.session != nil && s.session.Active()
}
// === Lifecycle APIs ===
// UpdateEpoch advances the sender's epoch. Invalidates stale sessions.
func (s *Sender) UpdateEpoch(epoch uint64) {
s.mu.Lock()
defer s.mu.Unlock()
if s.stopped || epoch <= s.Epoch {
if s.stopped || epoch <= s.epoch {
return
}
oldEpoch := s.Epoch
s.Epoch = epoch
if s.session != nil && s.session.Epoch < epoch {
oldEpoch := s.epoch
s.epoch = epoch
if s.session != nil && s.session.epoch < epoch {
s.session.invalidate(fmt.Sprintf("epoch_advanced_%d_to_%d", oldEpoch, epoch))
s.session = nil
s.State = StateDisconnected
s.state = StateDisconnected
}
}
@@ -53,59 +119,62 @@ func (s *Sender) UpdateEndpoint(ep Endpoint) {
if s.stopped {
return
}
if s.Endpoint.Changed(ep) && s.session != nil {
if s.endpoint.Changed(ep) && s.session != nil {
s.session.invalidate("endpoint_changed")
s.session = nil
s.State = StateDisconnected
s.state = StateDisconnected
}
s.endpoint = ep
}
// SessionOption configures a newly created session.
type SessionOption func(s *Session)
// WithBudget attaches a catch-up budget to the session.
func WithBudget(budget CatchUpBudget) SessionOption {
return func(s *Session) {
b := budget // copy
s.budget = &b
}
s.Endpoint = ep
}
// AttachSession creates a new recovery session. Epoch must match sender epoch.
func (s *Sender) AttachSession(epoch uint64, kind SessionKind) (*Session, error) {
func (s *Sender) AttachSession(epoch uint64, kind SessionKind, opts ...SessionOption) (uint64, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.stopped {
return nil, fmt.Errorf("sender stopped")
return 0, fmt.Errorf("sender stopped")
}
if epoch != s.Epoch {
return nil, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.Epoch, epoch)
if epoch != s.epoch {
return 0, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.epoch, epoch)
}
if s.session != nil && s.session.Active() {
return nil, fmt.Errorf("session already active (id=%d)", s.session.ID)
return 0, fmt.Errorf("session already active (id=%d)", s.session.id)
}
sess := newSession(s.replicaID, epoch, kind)
for _, opt := range opts {
opt(sess)
}
sess := newSession(s.ReplicaID, epoch, kind)
s.session = sess
return sess, nil
return sess.id, nil
}
// SupersedeSession invalidates current session and attaches new at sender epoch.
func (s *Sender) SupersedeSession(kind SessionKind, reason string) *Session {
func (s *Sender) SupersedeSession(kind SessionKind, reason string, opts ...SessionOption) uint64 {
s.mu.Lock()
defer s.mu.Unlock()
if s.stopped {
return nil
return 0
}
if s.session != nil {
s.session.invalidate(reason)
}
sess := newSession(s.ReplicaID, s.Epoch, kind)
sess := newSession(s.replicaID, s.epoch, kind)
for _, opt := range opts {
opt(sess)
}
s.session = sess
return sess
}
// Session returns the current session, or nil.
func (s *Sender) Session() *Session {
s.mu.Lock()
defer s.mu.Unlock()
return s.session
}
// Stopped returns true if the sender has been stopped.
func (s *Sender) Stopped() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.stopped
return sess.id
}
// Stop shuts down the sender.
@@ -130,26 +199,24 @@ func (s *Sender) InvalidateSession(reason string, targetState ReplicaState) {
s.session.invalidate(reason)
s.session = nil
}
s.State = targetState
s.state = targetState
}
// === Catch-up execution APIs ===
// BeginConnect transitions init → connecting.
func (s *Sender) BeginConnect(sessionID uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if !s.session.Advance(PhaseConnecting) {
return fmt.Errorf("cannot begin connect: phase=%s", s.session.Phase)
if !s.session.advance(PhaseConnecting) {
return fmt.Errorf("cannot begin connect: phase=%s", s.session.phase)
}
s.State = StateConnecting
s.state = StateConnecting
return nil
}
// RecordHandshake records handshake result and sets catch-up range.
func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -159,14 +226,13 @@ func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) e
if targetLSN < startLSN {
return fmt.Errorf("invalid range: target=%d < start=%d", targetLSN, startLSN)
}
if !s.session.Advance(PhaseHandshake) {
return fmt.Errorf("cannot record handshake: phase=%s", s.session.Phase)
if !s.session.advance(PhaseHandshake) {
return fmt.Errorf("cannot record handshake: phase=%s", s.session.phase)
}
s.session.SetRange(startLSN, targetLSN)
s.session.setRange(startLSN, targetLSN)
return nil
}
// RecordHandshakeWithOutcome records handshake and classifies the recovery outcome.
func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeResult) (RecoveryOutcome, error) {
outcome := ClassifyRecoveryOutcome(result)
s.mu.Lock()
@@ -174,106 +240,100 @@ func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeRe
if err := s.checkAuthority(sessionID); err != nil {
return outcome, err
}
if s.session.Phase != PhaseConnecting {
return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.Phase)
if s.session.phase != PhaseConnecting {
return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.phase)
}
if outcome == OutcomeNeedsRebuild {
s.session.invalidate("gap_exceeds_retention")
s.session = nil
s.State = StateNeedsRebuild
s.state = StateNeedsRebuild
return outcome, nil
}
if !s.session.Advance(PhaseHandshake) {
return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.Phase)
if !s.session.advance(PhaseHandshake) {
return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.phase)
}
switch outcome {
case OutcomeZeroGap:
s.session.SetRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN)
s.session.setRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN)
case OutcomeCatchUp:
if result.ReplicaFlushedLSN > result.CommittedLSN {
s.session.TruncateRequired = true
s.session.TruncateToLSN = result.CommittedLSN
s.session.SetRange(result.CommittedLSN, result.CommittedLSN)
s.session.truncateRequired = true
s.session.truncateToLSN = result.CommittedLSN
s.session.setRange(result.CommittedLSN, result.CommittedLSN)
} else {
s.session.SetRange(result.ReplicaFlushedLSN, result.CommittedLSN)
s.session.setRange(result.ReplicaFlushedLSN, result.CommittedLSN)
}
}
return outcome, nil
}
// BeginCatchUp transitions to catch-up phase. Rejects rebuild sessions.
// Freezes the target unconditionally.
func (s *Sender) BeginCatchUp(sessionID uint64, startTick ...uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Kind == SessionRebuild {
if s.session.kind == SessionRebuild {
return fmt.Errorf("rebuild sessions must use rebuild APIs")
}
if !s.session.Advance(PhaseCatchUp) {
return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.Phase)
if !s.session.advance(PhaseCatchUp) {
return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.phase)
}
s.State = StateCatchingUp
s.session.FrozenTargetLSN = s.session.TargetLSN
s.state = StateCatchingUp
s.session.frozenTargetLSN = s.session.targetLSN
if len(startTick) > 0 {
s.session.Tracker.StartTick = startTick[0]
s.session.Tracker.LastProgressTick = startTick[0]
s.session.tracker.StartTick = startTick[0]
s.session.tracker.LastProgressTick = startTick[0]
}
return nil
}
// RecordCatchUpProgress records catch-up progress. Rejects rebuild sessions.
// Entry counting uses LSN delta. Tick is required when ProgressDeadlineTicks > 0.
func (s *Sender) RecordCatchUpProgress(sessionID uint64, recoveredTo uint64, tick ...uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Kind == SessionRebuild {
if s.session.kind == SessionRebuild {
return fmt.Errorf("rebuild sessions must use rebuild APIs")
}
if s.session.Phase != PhaseCatchUp {
return fmt.Errorf("progress requires catchup phase, got %s", s.session.Phase)
if s.session.phase != PhaseCatchUp {
return fmt.Errorf("progress requires catchup phase, got %s", s.session.phase)
}
if recoveredTo <= s.session.RecoveredTo {
return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.RecoveredTo)
if recoveredTo <= s.session.recoveredTo {
return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.recoveredTo)
}
if s.session.FrozenTargetLSN > 0 && recoveredTo > s.session.FrozenTargetLSN {
return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.FrozenTargetLSN)
if s.session.frozenTargetLSN > 0 && recoveredTo > s.session.frozenTargetLSN {
return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.frozenTargetLSN)
}
if s.session.Budget != nil && s.session.Budget.ProgressDeadlineTicks > 0 && len(tick) == 0 {
if s.session.budget != nil && s.session.budget.ProgressDeadlineTicks > 0 && len(tick) == 0 {
return fmt.Errorf("tick required when ProgressDeadlineTicks > 0")
}
delta := recoveredTo - s.session.RecoveredTo
s.session.Tracker.EntriesReplayed += delta
s.session.UpdateProgress(recoveredTo)
delta := recoveredTo - s.session.recoveredTo
s.session.tracker.EntriesReplayed += delta
s.session.updateProgress(recoveredTo)
if len(tick) > 0 {
s.session.Tracker.LastProgressTick = tick[0]
s.session.tracker.LastProgressTick = tick[0]
}
return nil
}
// RecordTruncation confirms divergent tail cleanup.
func (s *Sender) RecordTruncation(sessionID uint64, truncatedToLSN uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if !s.session.TruncateRequired {
if !s.session.truncateRequired {
return fmt.Errorf("truncation not required")
}
if truncatedToLSN != s.session.TruncateToLSN {
return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.TruncateToLSN, truncatedToLSN)
if truncatedToLSN != s.session.truncateToLSN {
return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.truncateToLSN, truncatedToLSN)
}
s.session.TruncateRecorded = true
s.session.truncateRecorded = true
return nil
}
// CompleteSessionByID completes catch-up sessions. Rejects rebuild sessions.
func (s *Sender) CompleteSessionByID(sessionID uint64) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -281,19 +341,19 @@ func (s *Sender) CompleteSessionByID(sessionID uint64) bool {
return false
}
sess := s.session
if sess.Kind == SessionRebuild {
if sess.kind == SessionRebuild {
return false
}
if sess.TruncateRequired && !sess.TruncateRecorded {
if sess.truncateRequired && !sess.truncateRecorded {
return false
}
switch sess.Phase {
switch sess.phase {
case PhaseCatchUp:
if !sess.Converged() {
return false
}
case PhaseHandshake:
if sess.TargetLSN != sess.StartLSN {
if sess.targetLSN != sess.startLSN {
return false
}
default:
@@ -301,48 +361,46 @@
}
sess.complete()
s.session = nil
s.State = StateInSync
s.state = StateInSync
return true
}
// CheckBudget evaluates catch-up budget. Auto-escalates on violation.
func (s *Sender) CheckBudget(sessionID uint64, currentTick uint64) (BudgetViolation, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return BudgetOK, err
}
if s.session.Budget == nil {
if s.session.budget == nil {
return BudgetOK, nil
}
v := s.session.Budget.Check(s.session.Tracker, currentTick)
v := s.session.budget.Check(s.session.tracker, currentTick)
if v != BudgetOK {
s.session.invalidate(fmt.Sprintf("budget_%s", v))
s.session = nil
s.State = StateNeedsRebuild
s.state = StateNeedsRebuild
}
return v, nil
}
// === Rebuild execution APIs ===
// SelectRebuildSource chooses rebuild source. Requires PhaseHandshake.
func (s *Sender) SelectRebuildSource(sessionID uint64, snapshotLSN uint64, snapshotValid bool, committedLSN uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Kind != SessionRebuild {
if s.session.kind != SessionRebuild {
return fmt.Errorf("not a rebuild session")
}
if s.session.Phase != PhaseHandshake {
return fmt.Errorf("requires PhaseHandshake, got %s", s.session.Phase)
if s.session.phase != PhaseHandshake {
return fmt.Errorf("requires PhaseHandshake, got %s", s.session.phase)
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("rebuild state not initialized")
}
return s.session.Rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN)
return s.session.rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN)
}
func (s *Sender) BeginRebuildTransfer(sessionID uint64) error {
@@ -351,10 +409,10 @@ func (s *Sender) BeginRebuildTransfer(sessionID uint64) error {
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("no rebuild state")
}
return s.session.Rebuild.BeginTransfer()
return s.session.rebuild.BeginTransfer()
}
func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo uint64) error {
@@ -363,10 +421,10 @@ func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo u
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("no rebuild state")
}
return s.session.Rebuild.RecordTransferProgress(transferredTo)
return s.session.rebuild.RecordTransferProgress(transferredTo)
}
func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error {
@@ -375,10 +433,10 @@ func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error {
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("no rebuild state")
}
return s.session.Rebuild.BeginTailReplay()
return s.session.rebuild.BeginTailReplay()
}
func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64) error {
@@ -387,32 +445,30 @@ func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64)
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("no rebuild state")
}
return s.session.Rebuild.RecordTailReplayProgress(replayedTo)
return s.session.rebuild.RecordTailReplayProgress(replayedTo)
}
// CompleteRebuild completes a rebuild session. Requires ReadyToComplete.
func (s *Sender) CompleteRebuild(sessionID uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.checkAuthority(sessionID); err != nil {
return err
}
if s.session.Rebuild == nil {
if s.session.rebuild == nil {
return fmt.Errorf("no rebuild state")
}
if err := s.session.Rebuild.Complete(); err != nil {
if err := s.session.rebuild.Complete(); err != nil {
return err
}
s.session.complete()
s.session = nil
s.State = StateInSync
s.state = StateInSync
return nil
}
// checkAuthority validates session ownership.
func (s *Sender) checkAuthority(sessionID uint64) error {
if s.stopped {
return fmt.Errorf("sender stopped")
@@ -420,11 +476,11 @@ func (s *Sender) checkAuthority(sessionID uint64) error {
if s.session == nil {
return fmt.Errorf("no active session")
}
if s.session.ID != sessionID {
return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.ID, sessionID)
if s.session.id != sessionID {
return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.id, sessionID)
}
if !s.session.Active() {
return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.Phase)
return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.phase)
}
return nil
}

120
sw-block/engine/replication/session.go

@@ -6,98 +6,100 @@ import "sync/atomic"
var sessionIDCounter atomic.Uint64
// Session represents one recovery attempt for a specific replica at a
// specific epoch. It is owned by a Sender and gated by session ID at
// every execution step.
//
// Lifecycle:
// - Created via Sender.AttachSession or Sender.SupersedeSession
// - Advanced through phases: init → connecting → handshake → catchup → completed
// - Invalidated by: epoch bump, endpoint change, sender stop, timeout
// - Stale sessions (wrong ID) are rejected at every execution API
// specific epoch. All mutable state is unexported — external code interacts
// through Sender execution APIs only.
type Session struct {
ID uint64
ReplicaID string
Epoch uint64
Kind SessionKind
Phase SessionPhase
InvalidateReason string
// Progress tracking.
StartLSN uint64 // gap start (exclusive)
TargetLSN uint64 // gap end (inclusive)
FrozenTargetLSN uint64 // frozen at BeginCatchUp — catch-up will not chase beyond
RecoveredTo uint64 // highest LSN recovered so far
// Truncation.
TruncateRequired bool
TruncateToLSN uint64
TruncateRecorded bool
// Budget (nil = no enforcement).
Budget *CatchUpBudget
Tracker BudgetCheck
// Rebuild state (non-nil when Kind == SessionRebuild).
Rebuild *RebuildState
id uint64
replicaID string
epoch uint64
kind SessionKind
phase SessionPhase
invalidateReason string
startLSN uint64
targetLSN uint64
frozenTargetLSN uint64
recoveredTo uint64
truncateRequired bool
truncateToLSN uint64
truncateRecorded bool
budget *CatchUpBudget
tracker BudgetCheck
rebuild *RebuildState
}
func newSession(replicaID string, epoch uint64, kind SessionKind) *Session {
s := &Session{
ID: sessionIDCounter.Add(1),
ReplicaID: replicaID,
Epoch: epoch,
Kind: kind,
Phase: PhaseInit,
id: sessionIDCounter.Add(1),
replicaID: replicaID,
epoch: epoch,
kind: kind,
phase: PhaseInit,
}
if kind == SessionRebuild {
s.Rebuild = NewRebuildState()
s.rebuild = NewRebuildState()
}
return s
}
// Read-only accessors.
func (s *Session) ID() uint64 { return s.id }
func (s *Session) ReplicaID() string { return s.replicaID }
func (s *Session) Epoch() uint64 { return s.epoch }
func (s *Session) Kind() SessionKind { return s.kind }
func (s *Session) Phase() SessionPhase { return s.phase }
func (s *Session) InvalidateReason() string { return s.invalidateReason }
func (s *Session) StartLSN() uint64 { return s.startLSN }
func (s *Session) TargetLSN() uint64 { return s.targetLSN }
func (s *Session) FrozenTargetLSN() uint64 { return s.frozenTargetLSN }
func (s *Session) RecoveredTo() uint64 { return s.recoveredTo }
// Active returns true if the session is not completed or invalidated.
func (s *Session) Active() bool {
return s.Phase != PhaseCompleted && s.Phase != PhaseInvalidated
return s.phase != PhaseCompleted && s.phase != PhaseInvalidated
}
// Advance moves to the next phase. Returns false if the transition is invalid.
func (s *Session) Advance(phase SessionPhase) bool {
// Converged returns true if recovery reached the target.
func (s *Session) Converged() bool {
return s.targetLSN > 0 && s.recoveredTo >= s.targetLSN
}
// Internal mutation methods — called by Sender under its lock.
func (s *Session) advance(phase SessionPhase) bool {
if !s.Active() {
return false
}
if !validTransitions[s.Phase][phase] {
if !validTransitions[s.phase][phase] {
return false
}
s.Phase = phase
s.phase = phase
return true
}
// SetRange sets the recovery LSN range.
func (s *Session) SetRange(start, target uint64) {
s.StartLSN = start
s.TargetLSN = target
func (s *Session) setRange(start, target uint64) {
s.startLSN = start
s.targetLSN = target
}
// UpdateProgress records catch-up progress (monotonic).
func (s *Session) UpdateProgress(recoveredTo uint64) {
if recoveredTo > s.RecoveredTo {
s.RecoveredTo = recoveredTo
func (s *Session) updateProgress(recoveredTo uint64) {
if recoveredTo > s.recoveredTo {
s.recoveredTo = recoveredTo
}
}
// Converged returns true if recovery reached the target.
func (s *Session) Converged() bool {
return s.TargetLSN > 0 && s.RecoveredTo >= s.TargetLSN
}
func (s *Session) complete() {
s.Phase = PhaseCompleted
s.phase = PhaseCompleted
}
func (s *Session) invalidate(reason string) {
if !s.Active() {
return
}
s.Phase = PhaseInvalidated
s.InvalidateReason = reason
s.phase = PhaseInvalidated
s.invalidateReason = reason
}
Loading…
Cancel
Save