diff --git a/server/cmd_list.go b/server/cmd_list.go index 7b0da6b..0acf52f 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -4,6 +4,7 @@ import ( "strconv" "time" + "bytes" "github.com/siddontang/go/hack" "github.com/siddontang/ledisdb/ledis" ) @@ -285,6 +286,15 @@ func brpoplpushCommand(c *client) error { return err } + var ttl int64 = -1 + if bytes.Compare(source, dest) == 0 { + var err error + ttl, err = c.db.LTTL(source) + if err != nil { + return err + } + } + ay, err := c.db.BRPop([][]byte{source}, timeout) if err != nil { return err @@ -305,6 +315,11 @@ func brpoplpushCommand(c *client) error { return err } + //reset ttl + if ttl != -1 { + c.db.LExpire(source, ttl) + } + c.resp.writeBulk(data) return nil @@ -336,6 +351,15 @@ func rpoplpushCommand(c *client) error { } source, dest := args[0], args[1] + var ttl int64 = -1 + if bytes.Compare(source, dest) == 0 { + var err error + ttl, err = c.db.LTTL(source) + if err != nil { + return err + } + } + data, err := c.db.RPop(source) if err != nil { return err @@ -351,6 +375,11 @@ func rpoplpushCommand(c *client) error { return err } + //reset ttl + if ttl != -1 { + c.db.LExpire(source, ttl) + } + c.resp.writeBulk(data) return nil } diff --git a/server/cmd_list_test.go b/server/cmd_list_test.go index c9bef53..12f0b8a 100644 --- a/server/cmd_list_test.go +++ b/server/cmd_list_test.go @@ -304,6 +304,9 @@ func TestRPopLPush(t *testing.T) { src := []byte("sr") des := []byte("de") + c.Do("lclear", src) + c.Do("lclear", des) + if _, err := goredis.Int(c.Do("rpoplpush", src, des)); err != goredis.ErrNil { t.Fatal(err) } @@ -363,6 +366,36 @@ func TestRPopLPush(t *testing.T) { } } +func TestRPopLPushSingleElement(t *testing.T) { + c := getTestConn() + defer c.Close() + + src := []byte("sr") + + c.Do("lclear", src) + if n, err := goredis.Int(c.Do("rpush", src, 1)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + ttl := 300 + if _, err := c.Do("lexpire", src, ttl); err != nil { + t.Fatal(err) + } + + if v, err := goredis.Int(c.Do("rpoplpush", src, src)); err != nil { + t.Fatal(err) + } else if v != 1 { + t.Fatal(v) + } + + if tl, err := goredis.Int(c.Do("lttl", src)); err != nil { + t.Fatal(err) + } else if tl == -1 || tl > ttl { + t.Fatal(tl) + } +} + func TestTrim(t *testing.T) { c := getTestConn() defer c.Close()