diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go index c24e295..c9609fb 100644 --- a/ledis/ledis_db.go +++ b/ledis/ledis_db.go @@ -39,6 +39,8 @@ type DB struct { setBatch *batch status uint8 + + lbkeys *lBlockKeys } func (l *Ledis) newDB(index uint8) *DB { @@ -60,6 +62,8 @@ func (l *Ledis) newDB(index uint8) *DB { d.binBatch = d.newBatch() d.setBatch = d.newBatch() + d.lbkeys = newLBlockKeys() + return d } diff --git a/ledis/multi.go b/ledis/multi.go index a549c2c..29abe34 100644 --- a/ledis/multi.go +++ b/ledis/multi.go @@ -47,6 +47,8 @@ func (db *DB) Multi() (*Multi, error) { m.DB.binBatch = m.newBatch() m.DB.setBatch = m.newBatch() + m.DB.lbkeys = db.lbkeys + return m, nil } diff --git a/ledis/t_list.go b/ledis/t_list.go index dd6e04b..e2f2655 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -1,9 +1,12 @@ package ledis import ( + "container/list" "encoding/binary" "errors" + "github.com/siddontang/go/hack" "github.com/siddontang/ledisdb/store" + "sync" "time" ) @@ -131,6 +134,11 @@ func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { db.lSetMeta(metaKey, headSeq, tailSeq) err = t.Commit() + + if err == nil { + db.lSignalAsReady(key, pushCnt) + } + return int64(size) + int64(pushCnt), err } @@ -485,3 +493,169 @@ func (db *DB) lEncodeMaxKey() []byte { ek[len(ek)-1] = LMetaType + 1 return ek } + +func (db *DB) BLPop(keys [][]byte, timeout int) ([]interface{}, error) { + return db.lblockPop(keys, listHeadSeq, timeout) +} + +func (db *DB) BRPop(keys [][]byte, timeout int) ([]interface{}, error) { + return db.lblockPop(keys, listTailSeq, timeout) +} + +func (db *DB) lblockPop(keys [][]byte, whereSeq int32, timeout int) ([]interface{}, error) { + ch := make(chan []byte) + + bkeys := [][]byte{} + for _, key := range keys { + v, err := db.lpop(key, whereSeq) + if err != nil { + return nil, err + } else if v != nil { + return []interface{}{key, v}, nil + } else { + if db.IsAutoCommit() { + //block wait can not be supported in transaction and multi + db.lbkeys.wait(key, ch) + bkeys = append(bkeys, key) + } + } + } + if len(bkeys) == 0 { + return nil, nil + } + + defer func() { + for _, key := range bkeys { + db.lbkeys.unwait(key, ch) + } + }() + + deadT := time.Now().Add(time.Duration(timeout) * time.Second) + + for { + if timeout == 0 { + key := <-ch + if v, err := db.lpop(key, whereSeq); err != nil { + return nil, err + } else if v == nil { + continue + } else { + return []interface{}{key, v}, nil + } + } else { + d := deadT.Sub(time.Now()) + if d < 0 { + return nil, nil + } + + select { + case key := <-ch: + if v, err := db.lpop(key, whereSeq); err != nil { + return nil, err + } else if v == nil { + db.lbkeys.wait(key, ch) + continue + } else { + return []interface{}{key, v}, nil + } + case <-time.After(d): + return nil, nil + } + } + + } +} + +func (db *DB) lSignalAsReady(key []byte, num int) { + if db.status == DBInTransaction { + //for transaction, only data can be pushed after tx commit and it is hard to signal + //so we don't handle it now + return + } + + db.lbkeys.signal(key, num) +} + +type lbKeyCh chan<- []byte + +type lBlockKeys struct { + sync.Mutex + + keys map[string]*list.List +} + +func newLBlockKeys() *lBlockKeys { + l := new(lBlockKeys) + + l.keys = make(map[string]*list.List) + return l +} + +func (l *lBlockKeys) signal(key []byte, num int) { + l.Lock() + defer l.Unlock() + + s := hack.String(key) + chs, ok := l.keys[s] + if !ok { + return + } + + var n *list.Element + + i := 0 + for e := chs.Front(); e != nil && i < num; e = n { + ch := e.Value.(lbKeyCh) + n = e.Next() + select { + case ch <- key: + chs.Remove(e) + i++ + default: + //waiter unwait + chs.Remove(e) + } + } + + if chs.Len() == 0 { + delete(l.keys, s) + } +} + +func (l *lBlockKeys) wait(key []byte, ch lbKeyCh) { + l.Lock() + defer l.Unlock() + + s := hack.String(key) + chs, ok := l.keys[s] + if !ok { + chs = list.New() + l.keys[s] = chs + } + + chs.PushBack(ch) +} + +func (l *lBlockKeys) unwait(key []byte, ch lbKeyCh) { + l.Lock() + defer l.Unlock() + + s := hack.String(key) + chs, ok := l.keys[s] + if !ok { + return + } else { + var n *list.Element + for e := chs.Front(); e != nil; e = n { + c := e.Value.(lbKeyCh) + n = e.Next() + if c == ch { + chs.Remove(e) + } + } + + if chs.Len() == 0 { + delete(l.keys, s) + } + } +} diff --git a/ledis/t_list_test.go b/ledis/t_list_test.go index 8373a43..4fd2633 100644 --- a/ledis/t_list_test.go +++ b/ledis/t_list_test.go @@ -2,7 +2,9 @@ package ledis import ( "fmt" + "sync" "testing" + "time" ) func TestListCodec(t *testing.T) { @@ -102,6 +104,36 @@ func TestListPersist(t *testing.T) { } } +func TestLBlock(t *testing.T) { + db := getTestDB() + + key1 := []byte("test_lblock_key1") + key2 := []byte("test_lblock_key2") + + var wg sync.WaitGroup + wg.Add(2) + + f := func(i int) { + defer wg.Done() + + ay, err := db.BLPop([][]byte{key1, key2}, 0) + if err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } + } + + go f(1) + go f(2) + + time.Sleep(100 * time.Millisecond) + + db.LPush(key1, []byte("value")) + db.LPush(key2, []byte("value")) + wg.Wait() +} + func TestLFlush(t *testing.T) { db := getTestDB() db.FlushAll() diff --git a/ledis/tx.go b/ledis/tx.go index a5ff883..5c1c52a 100644 --- a/ledis/tx.go +++ b/ledis/tx.go @@ -61,6 +61,8 @@ func (db *DB) Begin() (*Tx, error) { tx.DB.binBatch = tx.newBatch() tx.DB.setBatch = tx.newBatch() + tx.DB.lbkeys = db.lbkeys + return tx, nil } diff --git a/server/cmd_list.go b/server/cmd_list.go index 722d5f1..f6b4b4e 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -1,7 +1,9 @@ package server import ( + "github.com/siddontang/go/hack" "github.com/siddontang/ledisdb/ledis" + "strconv" ) func lpushCommand(c *client) error { @@ -249,7 +251,53 @@ func lxscanCommand(c *client) error { return nil } +func blpopCommand(c *client) error { + keys, timeout, err := lParseBPopArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.BLPop(keys, timeout); err != nil { + return err + } else { + c.resp.writeArray(ay) + } + return nil +} + +func brpopCommand(c *client) error { + keys, timeout, err := lParseBPopArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.BRPop(keys, timeout); err != nil { + return err + } else { + c.resp.writeArray(ay) + } + return nil + +} + +func lParseBPopArgs(c *client) (keys [][]byte, timeout int, err error) { + args := c.args + if len(args) < 2 { + err = ErrCmdParams + return + } + + if timeout, err = strconv.Atoi(hack.String(args[len(args)-1])); err != nil { + return + } + + keys = args[0 : len(args)-1] + return +} + func init() { + register("blpop", blpopCommand) + register("brpop", brpopCommand) register("lindex", lindexCommand) register("llen", llenCommand) register("lpop", lpopCommand)