diff --git a/timingwheel/timingwheel.go b/timingwheel/timingwheel.go index dc62b44..210d8fa 100644 --- a/timingwheel/timingwheel.go +++ b/timingwheel/timingwheel.go @@ -1,25 +1,36 @@ package timingwheel import ( + "github.com/siddontang/golib/log" + "sync" "time" ) -type regItem struct { - timeout time.Duration - reply chan chan bool +type TaskFunc func() + +type bucket struct { + c chan struct{} + tasks []TaskFunc } +const defaultTasksSize = 16 +const defaultTaskPool = 4 + type TimingWheel struct { + sync.Mutex + interval time.Duration ticker *time.Ticker - quit chan bool - - reg chan *regItem + quit chan struct{} maxTimeout time.Duration - buckets []chan bool - pos int + + buckets []bucket + + pos int + + tasks chan []TaskFunc } func NewTimingWheel(interval time.Duration, buckets int) *TimingWheel { @@ -27,46 +38,67 @@ func NewTimingWheel(interval time.Duration, buckets int) *TimingWheel { w.interval = interval - w.reg = make(chan *regItem, 128) - - w.quit = make(chan bool) + w.quit = make(chan struct{}) w.pos = 0 w.maxTimeout = time.Duration(interval * (time.Duration(buckets))) - w.buckets = make([]chan bool, buckets) + w.buckets = make([]bucket, buckets) for i := range w.buckets { - w.buckets[i] = make(chan bool) + w.buckets[i].c = make(chan struct{}) + w.buckets[i].tasks = make([]TaskFunc, 0, defaultTasksSize) } w.ticker = time.NewTicker(interval) go w.run() + w.tasks = make(chan []TaskFunc, 1024) + + for i := 0; i < defaultTaskPool; i++ { + go w.taskPool() + } + return w } func (w *TimingWheel) Stop() { - w.quit <- true + close(w.quit) } -func (w *TimingWheel) After(timeout time.Duration) <-chan bool { +func (w *TimingWheel) After(timeout time.Duration) <-chan struct{} { if timeout >= w.maxTimeout { panic("timeout too much, over maxtimeout") } - reply := make(chan chan bool) + w.Lock() - w.reg <- ®Item{timeout: timeout, reply: reply} + index := (w.pos + int(timeout/w.interval)) % len(w.buckets) - return <-reply + b := w.buckets[index].c + + w.Unlock() + + return b +} + +func (w *TimingWheel) AddTask(timeout time.Duration, f TaskFunc) { + if timeout >= w.maxTimeout { + panic("timeout too much, over maxtimeout") + } + + w.Lock() + + index := (w.pos + int(timeout/w.interval)) % len(w.buckets) + + w.buckets[index].tasks = append(w.buckets[index].tasks, f) + + w.Unlock() } func (w *TimingWheel) run() { for { select { - case item := <-w.reg: - w.register(item) case <-w.ticker.C: w.onTicker() case <-w.quit: @@ -76,20 +108,39 @@ func (w *TimingWheel) run() { } } -func (w *TimingWheel) register(item *regItem) { - timeout := item.timeout - - index := (w.pos + int(timeout/w.interval)) % len(w.buckets) - - b := w.buckets[index] - - item.reply <- b -} - func (w *TimingWheel) onTicker() { - close(w.buckets[w.pos]) + w.Lock() - w.buckets[w.pos] = make(chan bool) + lastC := w.buckets[w.pos].c + tasks := w.buckets[w.pos].tasks + + w.buckets[w.pos].c = make(chan struct{}) + w.buckets[w.pos].tasks = w.buckets[w.pos].tasks[0:0:defaultTasksSize] w.pos = (w.pos + 1) % len(w.buckets) + + w.Unlock() + + close(lastC) + + w.tasks <- tasks +} + +func (w *TimingWheel) taskPool() { + defer func() { + if e := recover(); e != nil { + log.Fatal("task pool fatal %v", e) + } + }() + + for { + select { + case tasks := <-w.tasks: + for _, task := range tasks { + task() + } + case <-w.quit: + return + } + } } diff --git a/timingwheel/timingwheel_test.go b/timingwheel/timingwheel_test.go index 6385831..fce45d1 100644 --- a/timingwheel/timingwheel_test.go +++ b/timingwheel/timingwheel_test.go @@ -8,12 +8,27 @@ import ( func TestTimingWheel(t *testing.T) { w := NewTimingWheel(1*time.Second, 10) - t.Log(time.Now().Unix()) + println(time.Now().Unix()) for { select { case <-w.After(1 * time.Second): - t.Log(time.Now().Unix()) + println(time.Now().Unix()) return } } } + +func TestTask(t *testing.T) { + w := NewTimingWheel(1*time.Second, 10) + + r := make(chan struct{}) + f := func() { + println("hello world") + r <- struct{}{} + } + + w.AddTask(1*time.Second, f) + + <-r + println("over") +}