diff --git a/timingwheel/timingwheel.go b/timingwheel/timingwheel.go index a27d334..251993a 100644 --- a/timingwheel/timingwheel.go +++ b/timingwheel/timingwheel.go @@ -1,20 +1,10 @@ package timingwheel import ( - "github.com/siddontang/golib/log" "sync" "time" ) -type TaskFunc func() - -type bucket struct { - c chan struct{} - tasks []TaskFunc -} - -const defaultTasksSize = 16 - type TimingWheel struct { sync.Mutex @@ -25,7 +15,7 @@ type TimingWheel struct { maxTimeout time.Duration - buckets []bucket + cs []chan struct{} pos int } @@ -40,11 +30,10 @@ func NewTimingWheel(interval time.Duration, buckets int) *TimingWheel { w.maxTimeout = time.Duration(interval * (time.Duration(buckets))) - w.buckets = make([]bucket, buckets) + w.cs = make([]chan struct{}, buckets) - for i := range w.buckets { - w.buckets[i].c = make(chan struct{}) - w.buckets[i].tasks = make([]TaskFunc, 0, defaultTasksSize) + for i := range w.cs { + w.cs[i] = make(chan struct{}) } w.ticker = time.NewTicker(interval) @@ -64,29 +53,15 @@ func (w *TimingWheel) After(timeout time.Duration) <-chan struct{} { w.Lock() - index := (w.pos + int(timeout/w.interval)) % len(w.buckets) + index := (w.pos + int(timeout/w.interval)) % len(w.cs) - b := w.buckets[index].c + b := w.cs[index] w.Unlock() return b } -func (w *TimingWheel) AddTask(timeout time.Duration, f TaskFunc) { - if timeout >= w.maxTimeout { - panic("timeout too much, over maxtimeout") - } - - w.Lock() - - index := (w.pos + int(timeout/w.interval)) % len(w.buckets) - - w.buckets[index].tasks = append(w.buckets[index].tasks, f) - - w.Unlock() -} - func (w *TimingWheel) run() { for { select { @@ -102,30 +77,12 @@ func (w *TimingWheel) run() { func (w *TimingWheel) onTicker() { w.Lock() - lastC := w.buckets[w.pos].c - tasks := w.buckets[w.pos].tasks + lastC := w.cs[w.pos] + w.cs[w.pos] = make(chan struct{}) - w.buckets[w.pos].c = make(chan struct{}) - w.buckets[w.pos].tasks = w.buckets[w.pos].tasks[0:0:defaultTasksSize] - - w.pos = (w.pos + 1) % len(w.buckets) + w.pos = (w.pos + 1) % len(w.cs) w.Unlock() close(lastC) - - if len(tasks) > 0 { - f := func(tasks []TaskFunc) { - defer func() { - if e := recover(); e != nil { - log.Fatal("run task fatal %v", e) - } - }() - for _, task := range tasks { - task() - } - } - - go f(tasks) - } } diff --git a/timingwheel/timingwheel_test.go b/timingwheel/timingwheel_test.go index fce45d1..8aa09a6 100644 --- a/timingwheel/timingwheel_test.go +++ b/timingwheel/timingwheel_test.go @@ -6,29 +6,12 @@ import ( ) func TestTimingWheel(t *testing.T) { - w := NewTimingWheel(1*time.Second, 10) + w := NewTimingWheel(100*time.Millisecond, 10) - println(time.Now().Unix()) for { select { - case <-w.After(1 * time.Second): - println(time.Now().Unix()) + case <-w.After(200 * time.Millisecond): return } } } - -func TestTask(t *testing.T) { - w := NewTimingWheel(1*time.Second, 10) - - r := make(chan struct{}) - f := func() { - println("hello world") - r <- struct{}{} - } - - w.AddTask(1*time.Second, f) - - <-r - println("over") -}