timingwheel now only support After func

This commit is contained in:
siddontang 2014-04-15 16:37:22 +08:00
parent 5cd336a6a1
commit b0d7cb4278
2 changed files with 11 additions and 71 deletions

View File

@ -1,20 +1,10 @@
package timingwheel package timingwheel
import ( import (
"github.com/siddontang/golib/log"
"sync" "sync"
"time" "time"
) )
type TaskFunc func()
type bucket struct {
c chan struct{}
tasks []TaskFunc
}
const defaultTasksSize = 16
type TimingWheel struct { type TimingWheel struct {
sync.Mutex sync.Mutex
@ -25,7 +15,7 @@ type TimingWheel struct {
maxTimeout time.Duration maxTimeout time.Duration
buckets []bucket cs []chan struct{}
pos int pos int
} }
@ -40,11 +30,10 @@ func NewTimingWheel(interval time.Duration, buckets int) *TimingWheel {
w.maxTimeout = time.Duration(interval * (time.Duration(buckets))) w.maxTimeout = time.Duration(interval * (time.Duration(buckets)))
w.buckets = make([]bucket, buckets) w.cs = make([]chan struct{}, buckets)
for i := range w.buckets { for i := range w.cs {
w.buckets[i].c = make(chan struct{}) w.cs[i] = make(chan struct{})
w.buckets[i].tasks = make([]TaskFunc, 0, defaultTasksSize)
} }
w.ticker = time.NewTicker(interval) w.ticker = time.NewTicker(interval)
@ -64,29 +53,15 @@ func (w *TimingWheel) After(timeout time.Duration) <-chan struct{} {
w.Lock() w.Lock()
index := (w.pos + int(timeout/w.interval)) % len(w.buckets) index := (w.pos + int(timeout/w.interval)) % len(w.cs)
b := w.buckets[index].c b := w.cs[index]
w.Unlock() w.Unlock()
return b return b
} }
func (w *TimingWheel) AddTask(timeout time.Duration, f TaskFunc) {
if timeout >= w.maxTimeout {
panic("timeout too much, over maxtimeout")
}
w.Lock()
index := (w.pos + int(timeout/w.interval)) % len(w.buckets)
w.buckets[index].tasks = append(w.buckets[index].tasks, f)
w.Unlock()
}
func (w *TimingWheel) run() { func (w *TimingWheel) run() {
for { for {
select { select {
@ -102,30 +77,12 @@ func (w *TimingWheel) run() {
func (w *TimingWheel) onTicker() { func (w *TimingWheel) onTicker() {
w.Lock() w.Lock()
lastC := w.buckets[w.pos].c lastC := w.cs[w.pos]
tasks := w.buckets[w.pos].tasks w.cs[w.pos] = make(chan struct{})
w.buckets[w.pos].c = make(chan struct{}) w.pos = (w.pos + 1) % len(w.cs)
w.buckets[w.pos].tasks = w.buckets[w.pos].tasks[0:0:defaultTasksSize]
w.pos = (w.pos + 1) % len(w.buckets)
w.Unlock() w.Unlock()
close(lastC) close(lastC)
if len(tasks) > 0 {
f := func(tasks []TaskFunc) {
defer func() {
if e := recover(); e != nil {
log.Fatal("run task fatal %v", e)
}
}()
for _, task := range tasks {
task()
}
}
go f(tasks)
}
} }

View File

@ -6,29 +6,12 @@ import (
) )
func TestTimingWheel(t *testing.T) { func TestTimingWheel(t *testing.T) {
w := NewTimingWheel(1*time.Second, 10) w := NewTimingWheel(100*time.Millisecond, 10)
println(time.Now().Unix())
for { for {
select { select {
case <-w.After(1 * time.Second): case <-w.After(200 * time.Millisecond):
println(time.Now().Unix())
return return
} }
} }
} }
func TestTask(t *testing.T) {
w := NewTimingWheel(1*time.Second, 10)
r := make(chan struct{})
f := func() {
println("hello world")
r <- struct{}{}
}
w.AddTask(1*time.Second, f)
<-r
println("over")
}