You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
3.6 KiB
146 lines
3.6 KiB
package util
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRetryUntil(t *testing.T) {
|
|
// Test case 1: Function succeeds immediately
|
|
t.Run("SucceedsImmediately", func(t *testing.T) {
|
|
callCount := 0
|
|
err := RetryUntil("test", func() error {
|
|
callCount++
|
|
return nil
|
|
}, func(err error) bool {
|
|
return false
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if callCount != 1 {
|
|
t.Errorf("Expected 1 call, got %d", callCount)
|
|
}
|
|
})
|
|
|
|
// Test case 2: Function fails with retryable error, then succeeds
|
|
t.Run("SucceedsAfterRetry", func(t *testing.T) {
|
|
callCount := 0
|
|
err := RetryUntil("test", func() error {
|
|
callCount++
|
|
if callCount < 3 {
|
|
return errors.New("retryable error")
|
|
}
|
|
return nil
|
|
}, func(err error) bool {
|
|
return err.Error() == "retryable error"
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if callCount != 3 {
|
|
t.Errorf("Expected 3 calls, got %d", callCount)
|
|
}
|
|
})
|
|
|
|
// Test case 3: Function fails with non-retryable error
|
|
t.Run("FailsNonRetryable", func(t *testing.T) {
|
|
callCount := 0
|
|
err := RetryUntil("test", func() error {
|
|
callCount++
|
|
return errors.New("fatal error")
|
|
}, func(err error) bool {
|
|
return err.Error() == "retryable error"
|
|
})
|
|
|
|
if err == nil || err.Error() != "fatal error" {
|
|
t.Errorf("Expected 'fatal error', got %v", err)
|
|
}
|
|
if callCount != 1 {
|
|
t.Errorf("Expected 1 call, got %d", callCount)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRetryWithBackoff(t *testing.T) {
|
|
retryableErr := errors.New("unavailable")
|
|
shouldRetry := func(err error) bool { return err == retryableErr }
|
|
|
|
t.Run("SucceedsAfterRetries", func(t *testing.T) {
|
|
callCount := 0
|
|
err := RetryWithBackoff(context.Background(), "test", 30*time.Second, shouldRetry, func() error {
|
|
callCount++
|
|
if callCount < 3 {
|
|
return retryableErr
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Errorf("expected success, got %v", err)
|
|
}
|
|
if callCount != 3 {
|
|
t.Errorf("expected 3 calls, got %d", callCount)
|
|
}
|
|
})
|
|
|
|
t.Run("StopsOnNonRetryableError", func(t *testing.T) {
|
|
callCount := 0
|
|
fatalErr := errors.New("fatal")
|
|
err := RetryWithBackoff(context.Background(), "test", 30*time.Second, shouldRetry, func() error {
|
|
callCount++
|
|
return fatalErr
|
|
})
|
|
if err != fatalErr {
|
|
t.Errorf("expected fatal error, got %v", err)
|
|
}
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 call, got %d", callCount)
|
|
}
|
|
})
|
|
|
|
t.Run("StopsOnContextCancel", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
callCount := 0
|
|
start := time.Now()
|
|
err := RetryWithBackoff(ctx, "test", 30*time.Second, shouldRetry, func() error {
|
|
callCount++
|
|
return retryableErr
|
|
})
|
|
elapsed := time.Since(start)
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Errorf("expected DeadlineExceeded, got %v", err)
|
|
}
|
|
if callCount <= 1 {
|
|
t.Errorf("expected multiple calls, got %d", callCount)
|
|
}
|
|
if elapsed > 5*time.Second {
|
|
t.Errorf("took %v, expected to stop near 2s deadline", elapsed)
|
|
}
|
|
})
|
|
|
|
t.Run("StopsOnMaxDuration", func(t *testing.T) {
|
|
callCount := 0
|
|
start := time.Now()
|
|
err := RetryWithBackoff(context.Background(), "test", 3*time.Second, shouldRetry, func() error {
|
|
callCount++
|
|
return retryableErr
|
|
})
|
|
elapsed := time.Since(start)
|
|
if err != retryableErr {
|
|
t.Errorf("expected retryable error, got %v", err)
|
|
}
|
|
if callCount <= 1 {
|
|
t.Errorf("expected multiple calls, got %d", callCount)
|
|
}
|
|
// Should stop around 3s (maxDuration), not run forever
|
|
if elapsed > 6*time.Second {
|
|
t.Errorf("took %v, expected to stop near 3s maxDuration", elapsed)
|
|
}
|
|
})
|
|
}
|