From 3e39e52ddfd6c196c5fb34e6a65ffbce6bbae637 Mon Sep 17 00:00:00 2001 From: Evan Borgstrom Date: Wed, 25 Sep 2019 09:59:28 +0800 Subject: [PATCH] Fix concurrent access data race --- backoff.go | 24 +++++++++++++----------- backoff_test.go | 14 ++++++++------ 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/backoff.go b/backoff.go index b4941b6..e4fb068 100644 --- a/backoff.go +++ b/backoff.go @@ -4,6 +4,7 @@ package backoff import ( "math" "math/rand" + "sync/atomic" "time" ) @@ -14,19 +15,20 @@ import ( // Backoff is not generally concurrent-safe, but the ForAttempt method can // be used concurrently. type Backoff struct { - //Factor is the multiplying factor for each increment step - attempt, Factor float64 - //Jitter eases contention by randomizing backoff steps + attempt uint64 + // Factor is the multiplying factor for each increment step + Factor float64 + // Jitter eases contention by randomizing backoff steps Jitter bool - //Min and Max are the minimum and maximum values of the counter + // Min and Max are the minimum and maximum values of the counter Min, Max time.Duration } // Duration returns the duration for the current attempt before incrementing // the attempt counter. See ForAttempt. func (b *Backoff) Duration() time.Duration { - d := b.ForAttempt(b.attempt) - b.attempt++ + d := b.ForAttempt(atomic.LoadUint64(&b.attempt)) + atomic.AddUint64(&b.attempt, 1) return d } @@ -38,7 +40,7 @@ const maxInt64 = float64(math.MaxInt64 - 512) // attempt should be 0. // // ForAttempt is concurrent-safe. -func (b *Backoff) ForAttempt(attempt float64) time.Duration { +func (b *Backoff) ForAttempt(attempt uint64) time.Duration { // Zero-values are nonsensical, so we use // them to apply defaults min := b.Min @@ -59,7 +61,7 @@ func (b *Backoff) ForAttempt(attempt float64) time.Duration { } //calculate this duration minf := float64(min) - durf := minf * math.Pow(factor, attempt) + durf := minf * math.Pow(factor, float64(attempt)) if b.Jitter { durf = rand.Float64()*(durf-minf) + minf } @@ -80,12 +82,12 @@ func (b *Backoff) ForAttempt(attempt float64) time.Duration { // Reset restarts the current attempt counter at zero. func (b *Backoff) Reset() { - b.attempt = 0 + atomic.StoreUint64(&b.attempt, 0) } // Attempt returns the current attempt counter value. -func (b *Backoff) Attempt() float64 { - return b.attempt +func (b *Backoff) Attempt() uint64 { + return atomic.LoadUint64(&b.attempt) } // Copy returns a backoff with equals constraints as the original diff --git a/backoff_test.go b/backoff_test.go index 3f49f0a..7b6a6a9 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -83,17 +83,17 @@ func TestGetAttempt(t *testing.T) { Max: 10 * time.Second, Factor: 2, } - equals(t, b.Attempt(), float64(0)) + equals(t, b.Attempt(), uint64(0)) equals(t, b.Duration(), 100*time.Millisecond) - equals(t, b.Attempt(), float64(1)) + equals(t, b.Attempt(), uint64(1)) equals(t, b.Duration(), 200*time.Millisecond) - equals(t, b.Attempt(), float64(2)) + equals(t, b.Attempt(), uint64(2)) equals(t, b.Duration(), 400*time.Millisecond) - equals(t, b.Attempt(), float64(3)) + equals(t, b.Attempt(), uint64(3)) b.Reset() - equals(t, b.Attempt(), float64(0)) + equals(t, b.Attempt(), uint64(0)) equals(t, b.Duration(), 100*time.Millisecond) - equals(t, b.Attempt(), float64(1)) + equals(t, b.Attempt(), uint64(1)) } func TestJitter(t *testing.T) { @@ -142,6 +142,7 @@ func TestConcurrent(t *testing.T) { } func between(t *testing.T, actual, low, high time.Duration) { + t.Helper() if actual < low { t.Fatalf("Got %s, Expecting >= %s", actual, low) } @@ -151,6 +152,7 @@ func between(t *testing.T, actual, low, high time.Duration) { } func equals(t *testing.T, v1, v2 interface{}) { + t.Helper() if !reflect.DeepEqual(v1, v2) { t.Fatalf("Got %v, Expecting %v", v1, v2) }