You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
530 lines
15 KiB
530 lines
15 KiB
package replication
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// Sender owns the replication channel to one replica. It is the authority
|
|
// boundary for all execution operations. All mutable state is unexported —
|
|
// external code reads state through accessors and mutates through execution APIs.
|
|
type Sender struct {
|
|
mu sync.Mutex
|
|
|
|
replicaID string
|
|
endpoint Endpoint
|
|
epoch uint64
|
|
state ReplicaState
|
|
|
|
session *Session
|
|
stopped bool
|
|
}
|
|
|
|
// NewSender creates a sender for a replica.
|
|
func NewSender(replicaID string, endpoint Endpoint, epoch uint64) *Sender {
|
|
return &Sender{
|
|
replicaID: replicaID,
|
|
endpoint: endpoint,
|
|
epoch: epoch,
|
|
state: StateDisconnected,
|
|
}
|
|
}
|
|
|
|
// Read-only accessors.
|
|
|
|
// ReplicaID returns the replica identifier the sender was created with.
func (s *Sender) ReplicaID() string {
	s.mu.Lock()
	id := s.replicaID
	s.mu.Unlock()
	return id
}
|
|
// Endpoint returns the endpoint currently stored on the sender.
func (s *Sender) Endpoint() Endpoint {
	s.mu.Lock()
	ep := s.endpoint
	s.mu.Unlock()
	return ep
}
|
|
// Epoch returns the sender's current epoch.
func (s *Sender) Epoch() uint64 {
	s.mu.Lock()
	current := s.epoch
	s.mu.Unlock()
	return current
}
|
|
// State returns the replica state currently recorded on the sender.
func (s *Sender) State() ReplicaState {
	s.mu.Lock()
	current := s.state
	s.mu.Unlock()
	return current
}
|
|
// Stopped reports whether Stop has been called on the sender.
func (s *Sender) Stopped() bool {
	s.mu.Lock()
	halted := s.stopped
	s.mu.Unlock()
	return halted
}
|
|
|
|
// SessionSnapshot returns a read-only copy of the current session state.
|
|
// Returns nil if no session is active. The returned snapshot is disconnected
|
|
// from the live session — mutations to the Sender do not affect it.
|
|
func (s *Sender) SessionSnapshot() *SessionSnapshot {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.session == nil {
|
|
return nil
|
|
}
|
|
snap := &SessionSnapshot{
|
|
ID: s.session.id,
|
|
ReplicaID: s.session.replicaID,
|
|
Epoch: s.session.epoch,
|
|
Kind: s.session.kind,
|
|
Phase: s.session.phase,
|
|
InvalidateReason: s.session.invalidateReason,
|
|
StartLSN: s.session.startLSN,
|
|
TargetLSN: s.session.targetLSN,
|
|
FrozenTargetLSN: s.session.frozenTargetLSN,
|
|
RecoveredTo: s.session.recoveredTo,
|
|
Active: s.session.Active(),
|
|
TruncateRequired: s.session.truncateRequired,
|
|
TruncateToLSN: s.session.truncateToLSN,
|
|
TruncateRecorded: s.session.truncateRecorded,
|
|
}
|
|
if s.session.rebuild != nil {
|
|
snap.RebuildSource = s.session.rebuild.Source
|
|
snap.RebuildPhase = s.session.rebuild.Phase
|
|
}
|
|
return snap
|
|
}
|
|
|
|
// SessionSnapshot is a read-only copy of session state for external inspection.
|
|
type SessionSnapshot struct {
|
|
ID uint64
|
|
ReplicaID string
|
|
Epoch uint64
|
|
Kind SessionKind
|
|
Phase SessionPhase
|
|
InvalidateReason string
|
|
StartLSN uint64
|
|
TargetLSN uint64
|
|
FrozenTargetLSN uint64
|
|
RecoveredTo uint64
|
|
Active bool
|
|
|
|
// Truncation state.
|
|
TruncateRequired bool
|
|
TruncateToLSN uint64
|
|
TruncateRecorded bool
|
|
|
|
// Rebuild state (nil if not a rebuild session).
|
|
RebuildSource RebuildSource
|
|
RebuildPhase RebuildPhase
|
|
}
|
|
|
|
// SessionID returns the current session ID, or 0 if no session.
|
|
func (s *Sender) SessionID() uint64 {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.session == nil {
|
|
return 0
|
|
}
|
|
return s.session.id
|
|
}
|
|
|
|
// HasActiveSession returns true if a session is currently active.
|
|
func (s *Sender) HasActiveSession() bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
return s.session != nil && s.session.Active()
|
|
}
|
|
|
|
// === Lifecycle APIs ===
|
|
|
|
// UpdateEpoch advances the sender's epoch. Invalidates stale sessions.
|
|
func (s *Sender) UpdateEpoch(epoch uint64) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.stopped || epoch <= s.epoch {
|
|
return
|
|
}
|
|
oldEpoch := s.epoch
|
|
s.epoch = epoch
|
|
if s.session != nil && s.session.epoch < epoch {
|
|
s.session.invalidate(fmt.Sprintf("epoch_advanced_%d_to_%d", oldEpoch, epoch))
|
|
s.session = nil
|
|
s.state = StateDisconnected
|
|
}
|
|
}
|
|
|
|
// UpdateEndpoint updates the target address. Invalidates session on change.
|
|
func (s *Sender) UpdateEndpoint(ep Endpoint) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.stopped {
|
|
return
|
|
}
|
|
if s.endpoint.Changed(ep) && s.session != nil {
|
|
s.session.invalidate("endpoint_changed")
|
|
s.session = nil
|
|
s.state = StateDisconnected
|
|
}
|
|
s.endpoint = ep
|
|
}
|
|
|
|
// SessionOption configures a newly created session.
|
|
type SessionOption func(s *Session)
|
|
|
|
// WithBudget attaches a catch-up budget to the session.
|
|
func WithBudget(budget CatchUpBudget) SessionOption {
|
|
return func(s *Session) {
|
|
b := budget // copy
|
|
s.budget = &b
|
|
}
|
|
}
|
|
|
|
// AttachSession creates a new recovery session. Epoch must match sender epoch.
|
|
func (s *Sender) AttachSession(epoch uint64, kind SessionKind, opts ...SessionOption) (uint64, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.stopped {
|
|
return 0, fmt.Errorf("sender stopped")
|
|
}
|
|
if epoch != s.epoch {
|
|
return 0, fmt.Errorf("epoch mismatch: sender=%d session=%d", s.epoch, epoch)
|
|
}
|
|
if s.session != nil && s.session.Active() {
|
|
return 0, fmt.Errorf("session already active (id=%d)", s.session.id)
|
|
}
|
|
sess := newSession(s.replicaID, epoch, kind)
|
|
for _, opt := range opts {
|
|
opt(sess)
|
|
}
|
|
s.session = sess
|
|
return sess.id, nil
|
|
}
|
|
|
|
// SupersedeSession invalidates current session and attaches new at sender epoch.
|
|
func (s *Sender) SupersedeSession(kind SessionKind, reason string, opts ...SessionOption) uint64 {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.stopped {
|
|
return 0
|
|
}
|
|
if s.session != nil {
|
|
s.session.invalidate(reason)
|
|
}
|
|
sess := newSession(s.replicaID, s.epoch, kind)
|
|
for _, opt := range opts {
|
|
opt(sess)
|
|
}
|
|
s.session = sess
|
|
return sess.id
|
|
}
|
|
|
|
// Stop shuts down the sender.
|
|
func (s *Sender) Stop() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.stopped {
|
|
return
|
|
}
|
|
s.stopped = true
|
|
if s.session != nil {
|
|
s.session.invalidate("sender_stopped")
|
|
s.session = nil
|
|
}
|
|
}
|
|
|
|
// InvalidateSession invalidates current session with target state.
|
|
func (s *Sender) InvalidateSession(reason string, targetState ReplicaState) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.session != nil {
|
|
s.session.invalidate(reason)
|
|
s.session = nil
|
|
}
|
|
s.state = targetState
|
|
}
|
|
|
|
// === Catch-up execution APIs ===
|
|
|
|
func (s *Sender) BeginConnect(sessionID uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if !s.session.advance(PhaseConnecting) {
|
|
return fmt.Errorf("cannot begin connect: phase=%s", s.session.phase)
|
|
}
|
|
s.state = StateConnecting
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) RecordHandshake(sessionID uint64, startLSN, targetLSN uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if targetLSN < startLSN {
|
|
return fmt.Errorf("invalid range: target=%d < start=%d", targetLSN, startLSN)
|
|
}
|
|
if !s.session.advance(PhaseHandshake) {
|
|
return fmt.Errorf("cannot record handshake: phase=%s", s.session.phase)
|
|
}
|
|
s.session.setRange(startLSN, targetLSN)
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) RecordHandshakeWithOutcome(sessionID uint64, result HandshakeResult) (RecoveryOutcome, error) {
|
|
outcome := ClassifyRecoveryOutcome(result)
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return outcome, err
|
|
}
|
|
if s.session.phase != PhaseConnecting {
|
|
return outcome, fmt.Errorf("handshake requires connecting, got %s", s.session.phase)
|
|
}
|
|
if outcome == OutcomeNeedsRebuild {
|
|
s.session.invalidate("gap_exceeds_retention")
|
|
s.session = nil
|
|
s.state = StateNeedsRebuild
|
|
return outcome, nil
|
|
}
|
|
if !s.session.advance(PhaseHandshake) {
|
|
return outcome, fmt.Errorf("cannot advance to handshake: phase=%s", s.session.phase)
|
|
}
|
|
switch outcome {
|
|
case OutcomeZeroGap:
|
|
s.session.setRange(result.ReplicaFlushedLSN, result.ReplicaFlushedLSN)
|
|
case OutcomeCatchUp:
|
|
if result.ReplicaFlushedLSN > result.CommittedLSN {
|
|
s.session.truncateRequired = true
|
|
s.session.truncateToLSN = result.CommittedLSN
|
|
s.session.setRange(result.CommittedLSN, result.CommittedLSN)
|
|
} else {
|
|
s.session.setRange(result.ReplicaFlushedLSN, result.CommittedLSN)
|
|
}
|
|
}
|
|
return outcome, nil
|
|
}
|
|
|
|
// RecordHandshakeFromHistory records the handshake using the primary's
|
|
// RetainedHistory as the authoritative recoverability source. This is the
|
|
// preferred engine-level API — it ensures recovery decisions are backed
|
|
// by actual retention state, not caller-supplied values.
|
|
func (s *Sender) RecordHandshakeFromHistory(sessionID uint64, replicaFlushedLSN uint64, history *RetainedHistory) (RecoveryOutcome, *RecoverabilityProof, error) {
|
|
if history == nil {
|
|
return OutcomeNeedsRebuild, nil, fmt.Errorf("nil RetainedHistory")
|
|
}
|
|
proof := history.ProveRecoverability(replicaFlushedLSN)
|
|
hr := history.MakeHandshakeResult(replicaFlushedLSN)
|
|
outcome, err := s.RecordHandshakeWithOutcome(sessionID, hr)
|
|
return outcome, &proof, err
|
|
}
|
|
|
|
// SelectRebuildFromHistory selects the rebuild source using the primary's
|
|
// RetainedHistory. This is the preferred engine-level API — it ensures
|
|
// the rebuild-source decision accounts for both checkpoint trust AND
|
|
// tail replayability.
|
|
func (s *Sender) SelectRebuildFromHistory(sessionID uint64, history *RetainedHistory) error {
|
|
if history == nil {
|
|
return fmt.Errorf("nil RetainedHistory")
|
|
}
|
|
source, snapLSN := history.RebuildSourceDecision()
|
|
valid := source == RebuildSnapshotTail
|
|
return s.SelectRebuildSource(sessionID, snapLSN, valid, history.CommittedLSN)
|
|
}
|
|
|
|
func (s *Sender) BeginCatchUp(sessionID uint64, startTick ...uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.kind == SessionRebuild {
|
|
return fmt.Errorf("rebuild sessions must use rebuild APIs")
|
|
}
|
|
if !s.session.advance(PhaseCatchUp) {
|
|
return fmt.Errorf("cannot begin catch-up: phase=%s", s.session.phase)
|
|
}
|
|
s.state = StateCatchingUp
|
|
s.session.frozenTargetLSN = s.session.targetLSN
|
|
if len(startTick) > 0 {
|
|
s.session.tracker.StartTick = startTick[0]
|
|
s.session.tracker.LastProgressTick = startTick[0]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) RecordCatchUpProgress(sessionID uint64, recoveredTo uint64, tick ...uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.kind == SessionRebuild {
|
|
return fmt.Errorf("rebuild sessions must use rebuild APIs")
|
|
}
|
|
if s.session.phase != PhaseCatchUp {
|
|
return fmt.Errorf("progress requires catchup phase, got %s", s.session.phase)
|
|
}
|
|
if recoveredTo <= s.session.recoveredTo {
|
|
return fmt.Errorf("progress regression: %d <= %d", recoveredTo, s.session.recoveredTo)
|
|
}
|
|
if s.session.frozenTargetLSN > 0 && recoveredTo > s.session.frozenTargetLSN {
|
|
return fmt.Errorf("progress %d exceeds frozen target %d", recoveredTo, s.session.frozenTargetLSN)
|
|
}
|
|
if s.session.budget != nil && s.session.budget.ProgressDeadlineTicks > 0 && len(tick) == 0 {
|
|
return fmt.Errorf("tick required when ProgressDeadlineTicks > 0")
|
|
}
|
|
delta := recoveredTo - s.session.recoveredTo
|
|
s.session.tracker.EntriesReplayed += delta
|
|
s.session.updateProgress(recoveredTo)
|
|
if len(tick) > 0 {
|
|
s.session.tracker.LastProgressTick = tick[0]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) RecordTruncation(sessionID uint64, truncatedToLSN uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if !s.session.truncateRequired {
|
|
return fmt.Errorf("truncation not required")
|
|
}
|
|
if truncatedToLSN != s.session.truncateToLSN {
|
|
return fmt.Errorf("truncation LSN mismatch: expected %d, got %d", s.session.truncateToLSN, truncatedToLSN)
|
|
}
|
|
s.session.truncateRecorded = true
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) CompleteSessionByID(sessionID uint64) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.checkAuthority(sessionID) != nil {
|
|
return false
|
|
}
|
|
sess := s.session
|
|
if sess.kind == SessionRebuild {
|
|
return false
|
|
}
|
|
if sess.truncateRequired && !sess.truncateRecorded {
|
|
return false
|
|
}
|
|
switch sess.phase {
|
|
case PhaseCatchUp:
|
|
if !sess.Converged() {
|
|
return false
|
|
}
|
|
case PhaseHandshake:
|
|
if sess.targetLSN != sess.startLSN {
|
|
return false
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
sess.complete()
|
|
s.session = nil
|
|
s.state = StateInSync
|
|
return true
|
|
}
|
|
|
|
func (s *Sender) CheckBudget(sessionID uint64, currentTick uint64) (BudgetViolation, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return BudgetOK, err
|
|
}
|
|
if s.session.budget == nil {
|
|
return BudgetOK, nil
|
|
}
|
|
v := s.session.budget.Check(s.session.tracker, currentTick)
|
|
if v != BudgetOK {
|
|
s.session.invalidate(fmt.Sprintf("budget_%s", v))
|
|
s.session = nil
|
|
s.state = StateNeedsRebuild
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
// === Rebuild execution APIs ===
|
|
|
|
func (s *Sender) SelectRebuildSource(sessionID uint64, snapshotLSN uint64, snapshotValid bool, committedLSN uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.kind != SessionRebuild {
|
|
return fmt.Errorf("not a rebuild session")
|
|
}
|
|
if s.session.phase != PhaseHandshake {
|
|
return fmt.Errorf("requires PhaseHandshake, got %s", s.session.phase)
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("rebuild state not initialized")
|
|
}
|
|
return s.session.rebuild.SelectSource(snapshotLSN, snapshotValid, committedLSN)
|
|
}
|
|
|
|
func (s *Sender) BeginRebuildTransfer(sessionID uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("no rebuild state")
|
|
}
|
|
return s.session.rebuild.BeginTransfer()
|
|
}
|
|
|
|
func (s *Sender) RecordRebuildTransferProgress(sessionID uint64, transferredTo uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("no rebuild state")
|
|
}
|
|
return s.session.rebuild.RecordTransferProgress(transferredTo)
|
|
}
|
|
|
|
func (s *Sender) BeginRebuildTailReplay(sessionID uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("no rebuild state")
|
|
}
|
|
return s.session.rebuild.BeginTailReplay()
|
|
}
|
|
|
|
func (s *Sender) RecordRebuildTailProgress(sessionID uint64, replayedTo uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("no rebuild state")
|
|
}
|
|
return s.session.rebuild.RecordTailReplayProgress(replayedTo)
|
|
}
|
|
|
|
func (s *Sender) CompleteRebuild(sessionID uint64) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := s.checkAuthority(sessionID); err != nil {
|
|
return err
|
|
}
|
|
if s.session.rebuild == nil {
|
|
return fmt.Errorf("no rebuild state")
|
|
}
|
|
if err := s.session.rebuild.Complete(); err != nil {
|
|
return err
|
|
}
|
|
s.session.complete()
|
|
s.session = nil
|
|
s.state = StateInSync
|
|
return nil
|
|
}
|
|
|
|
func (s *Sender) checkAuthority(sessionID uint64) error {
|
|
if s.stopped {
|
|
return fmt.Errorf("sender stopped")
|
|
}
|
|
if s.session == nil {
|
|
return fmt.Errorf("no active session")
|
|
}
|
|
if s.session.id != sessionID {
|
|
return fmt.Errorf("session ID mismatch: active=%d requested=%d", s.session.id, sessionID)
|
|
}
|
|
if !s.session.Active() {
|
|
return fmt.Errorf("session %d not active (phase=%s)", sessionID, s.session.phase)
|
|
}
|
|
return nil
|
|
}
|