diff --git a/backoff.go b/backoff.go index 3ae595f..fb5d9b7 100644 --- a/backoff.go +++ b/backoff.go @@ -11,9 +11,9 @@ import ( //Used in conjunction with the time package. type Backoff struct { //Factor is the multiplying factor for each increment step - attempts, Factor int + attempts, Factor float64 //Min and Max are the minimum and maximum values of the counter - curr, Min, Max time.Duration + Min, Max time.Duration } //Returns the current value of the counter and then @@ -30,16 +30,16 @@ func (b *Backoff) Duration() time.Duration { if b.Factor == 0 { b.Factor = 2 } - if b.curr == 0 { - b.curr = b.Min + //calculate this duration + dur := float64(b.Min) * math.Pow(b.Factor, b.attempts) + //cap! + if dur > float64(b.Max) { + return b.Max } - - //calculate next duration in ms - ms := float64(b.curr) * math.Pow(float64(b.Factor), float64(b.attempts)) //bump attempts count b.attempts++ //return as a time.Duration - return time.Duration(math.Min(ms, float64(b.Max))) + return time.Duration(dur) } //Resets the current value of the counter back to Min diff --git a/backoff_test.go b/backoff_test.go index d68dd41..27d7e23 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -13,21 +13,45 @@ func Test1(t *testing.T) { Factor: 2, } - if b.Duration() != 100*time.Millisecond { - t.Error("Should be 100ms") - } - - if b.Duration() != 200*time.Millisecond { - t.Error("Should be 200ms") - } - - if b.Duration() != 400*time.Millisecond { - t.Error("Should be 400ms") - } - + equals(t, b.Duration(), 100*time.Millisecond) + equals(t, b.Duration(), 200*time.Millisecond) + equals(t, b.Duration(), 400*time.Millisecond) b.Reset() + equals(t, b.Duration(), 100*time.Millisecond) +} - if b.Duration() != 100*time.Millisecond { - t.Error("Should be 100ms again") +func Test2(t *testing.T) { + + b := &Backoff{ + Min: 100 * time.Millisecond, + Max: 10 * time.Second, + Factor: 1.5, + } + + equals(t, b.Duration(), 100*time.Millisecond) + equals(t, b.Duration(), 150*time.Millisecond) + equals(t, b.Duration(), 225*time.Millisecond) + b.Reset() + equals(t, b.Duration(), 100*time.Millisecond) +} + +func Test3(t *testing.T) { + + b := &Backoff{ + Min: 100 * time.Nanosecond, + Max: 10 * time.Second, + Factor: 1.75, + } + + equals(t, b.Duration(), 100*time.Nanosecond) + equals(t, b.Duration(), 175*time.Nanosecond) + equals(t, b.Duration(), 306*time.Nanosecond) + b.Reset() + equals(t, b.Duration(), 100*time.Nanosecond) +} + +func equals(t *testing.T, d1, d2 time.Duration) { + if d1 != d2 { + t.Fatalf("Got %s, Expecting %s", d1, d2) } }