diff --git a/backoff.go b/backoff.go index b4941b6..d113e68 100644 --- a/backoff.go +++ b/backoff.go @@ -4,6 +4,7 @@ package backoff import ( "math" "math/rand" + "sync/atomic" "time" ) @@ -14,19 +15,19 @@ import ( // Backoff is not generally concurrent-safe, but the ForAttempt method can // be used concurrently. type Backoff struct { - //Factor is the multiplying factor for each increment step - attempt, Factor float64 - //Jitter eases contention by randomizing backoff steps + attempt uint64 + // Factor is the multiplying factor for each increment step + Factor float64 + // Jitter eases contention by randomizing backoff steps Jitter bool - //Min and Max are the minimum and maximum values of the counter + // Min and Max are the minimum and maximum values of the counter Min, Max time.Duration } // Duration returns the duration for the current attempt before incrementing // the attempt counter. See ForAttempt. func (b *Backoff) Duration() time.Duration { - d := b.ForAttempt(b.attempt) - b.attempt++ + d := b.ForAttempt(float64(atomic.AddUint64(&b.attempt, 1) - 1)) return d } @@ -80,12 +81,12 @@ func (b *Backoff) ForAttempt(attempt float64) time.Duration { // Reset restarts the current attempt counter at zero. func (b *Backoff) Reset() { - b.attempt = 0 + atomic.StoreUint64(&b.attempt, 0) } // Attempt returns the current attempt counter value. func (b *Backoff) Attempt() float64 { - return b.attempt + return float64(atomic.LoadUint64(&b.attempt)) } // Copy returns a backoff with equals constraints as the original diff --git a/backoff_test.go b/backoff_test.go index d1c9845..90b68c2 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -2,6 +2,7 @@ package backoff import ( "reflect" + "sync" "testing" "time" ) @@ -120,7 +121,28 @@ func TestCopy(t *testing.T) { equals(t, b, b2) } +func TestConcurrent(t *testing.T) { + b := &Backoff{ + Min: 100 * time.Millisecond, + Max: 10 * time.Second, + Factor: 2, + } + + wg := &sync.WaitGroup{} + + test := func() { + time.Sleep(b.Duration()) + wg.Done() + } + + wg.Add(2) + go test() + go test() + wg.Wait() +} + func between(t *testing.T, actual, low, high time.Duration) { + t.Helper() if actual < low { t.Fatalf("Got %s, Expecting >= %s", actual, low) } @@ -130,6 +152,7 @@ func between(t *testing.T, actual, low, high time.Duration) { } func equals(t *testing.T, v1, v2 interface{}) { + t.Helper() if !reflect.DeepEqual(v1, v2) { t.Fatalf("Got %v, Expecting %v", v1, v2) }