timingwheel now only support After func
This commit is contained in:
parent
5cd336a6a1
commit
b0d7cb4278
|
@ -1,20 +1,10 @@
|
|||
package timingwheel
|
||||
|
||||
import (
|
||||
"github.com/siddontang/golib/log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TaskFunc func()
|
||||
|
||||
type bucket struct {
|
||||
c chan struct{}
|
||||
tasks []TaskFunc
|
||||
}
|
||||
|
||||
const defaultTasksSize = 16
|
||||
|
||||
type TimingWheel struct {
|
||||
sync.Mutex
|
||||
|
||||
|
@ -25,7 +15,7 @@ type TimingWheel struct {
|
|||
|
||||
maxTimeout time.Duration
|
||||
|
||||
buckets []bucket
|
||||
cs []chan struct{}
|
||||
|
||||
pos int
|
||||
}
|
||||
|
@ -40,11 +30,10 @@ func NewTimingWheel(interval time.Duration, buckets int) *TimingWheel {
|
|||
|
||||
w.maxTimeout = time.Duration(interval * (time.Duration(buckets)))
|
||||
|
||||
w.buckets = make([]bucket, buckets)
|
||||
w.cs = make([]chan struct{}, buckets)
|
||||
|
||||
for i := range w.buckets {
|
||||
w.buckets[i].c = make(chan struct{})
|
||||
w.buckets[i].tasks = make([]TaskFunc, 0, defaultTasksSize)
|
||||
for i := range w.cs {
|
||||
w.cs[i] = make(chan struct{})
|
||||
}
|
||||
|
||||
w.ticker = time.NewTicker(interval)
|
||||
|
@ -64,29 +53,15 @@ func (w *TimingWheel) After(timeout time.Duration) <-chan struct{} {
|
|||
|
||||
w.Lock()
|
||||
|
||||
index := (w.pos + int(timeout/w.interval)) % len(w.buckets)
|
||||
index := (w.pos + int(timeout/w.interval)) % len(w.cs)
|
||||
|
||||
b := w.buckets[index].c
|
||||
b := w.cs[index]
|
||||
|
||||
w.Unlock()
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (w *TimingWheel) AddTask(timeout time.Duration, f TaskFunc) {
|
||||
if timeout >= w.maxTimeout {
|
||||
panic("timeout too much, over maxtimeout")
|
||||
}
|
||||
|
||||
w.Lock()
|
||||
|
||||
index := (w.pos + int(timeout/w.interval)) % len(w.buckets)
|
||||
|
||||
w.buckets[index].tasks = append(w.buckets[index].tasks, f)
|
||||
|
||||
w.Unlock()
|
||||
}
|
||||
|
||||
func (w *TimingWheel) run() {
|
||||
for {
|
||||
select {
|
||||
|
@ -102,30 +77,12 @@ func (w *TimingWheel) run() {
|
|||
func (w *TimingWheel) onTicker() {
|
||||
w.Lock()
|
||||
|
||||
lastC := w.buckets[w.pos].c
|
||||
tasks := w.buckets[w.pos].tasks
|
||||
lastC := w.cs[w.pos]
|
||||
w.cs[w.pos] = make(chan struct{})
|
||||
|
||||
w.buckets[w.pos].c = make(chan struct{})
|
||||
w.buckets[w.pos].tasks = w.buckets[w.pos].tasks[0:0:defaultTasksSize]
|
||||
|
||||
w.pos = (w.pos + 1) % len(w.buckets)
|
||||
w.pos = (w.pos + 1) % len(w.cs)
|
||||
|
||||
w.Unlock()
|
||||
|
||||
close(lastC)
|
||||
|
||||
if len(tasks) > 0 {
|
||||
f := func(tasks []TaskFunc) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Fatal("run task fatal %v", e)
|
||||
}
|
||||
}()
|
||||
for _, task := range tasks {
|
||||
task()
|
||||
}
|
||||
}
|
||||
|
||||
go f(tasks)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,29 +6,12 @@ import (
|
|||
)
|
||||
|
||||
func TestTimingWheel(t *testing.T) {
|
||||
w := NewTimingWheel(1*time.Second, 10)
|
||||
w := NewTimingWheel(100*time.Millisecond, 10)
|
||||
|
||||
println(time.Now().Unix())
|
||||
for {
|
||||
select {
|
||||
case <-w.After(1 * time.Second):
|
||||
println(time.Now().Unix())
|
||||
case <-w.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTask(t *testing.T) {
|
||||
w := NewTimingWheel(1*time.Second, 10)
|
||||
|
||||
r := make(chan struct{})
|
||||
f := func() {
|
||||
println("hello world")
|
||||
r <- struct{}{}
|
||||
}
|
||||
|
||||
w.AddTask(1*time.Second, f)
|
||||
|
||||
<-r
|
||||
println("over")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue