diff --git a/ledis/t_list.go b/ledis/t_list.go index faf6d39..24ec6d2 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -216,6 +216,60 @@ func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { return value, err } +func (db *DB) ltrim2(key []byte, startP, stopP int64) (err error) { + if err := checkKeySize(key); err != nil { + return err + } + + t := db.listBatch + t.Lock() + defer t.Unlock() + + var headSeq int32 + var llen int32 + start := int32(startP) + stop := int32(stopP) + + ek := db.lEncodeMetaKey(key) + if headSeq, _, llen, err = db.lGetMeta(nil, ek); err != nil { + return err + } else { + if start < 0 { + start = llen + start + } + if stop < 0 { + stop = llen + stop + } + if start >= llen || start > stop { + db.lDelete(t, key) + db.rmExpire(t, ListType, key) + return t.Commit() + } + + if start < 0 { + start = 0 + } + if stop >= llen { + stop = llen - 1 + } + } + + if start > 0 { + for i := int32(0); i < start; i++ { + t.Delete(db.lEncodeListKey(key, headSeq+i)) + } + } + if stop < int32(llen-1) { + for i := int32(stop + 1); i < llen; i++ { + t.Delete(db.lEncodeListKey(key, headSeq+i)) + } + } + + db.lSetMeta(ek, headSeq+start, headSeq+stop) + + return t.Commit() +} + func (db *DB) ltrim(key []byte, trimSize, whereSeq int32) (int32, error) { if err := checkKeySize(key); err != nil { return 0, err @@ -408,6 +462,10 @@ func (db *DB) LPop(key []byte) ([]byte, error) { return db.lpop(key, listHeadSeq) } +func (db *DB) LTrim(key []byte, start, stop int64) error { + return db.ltrim2(key, start, stop) +} + func (db *DB) LTrimFront(key []byte, trimSize int32) (int32, error) { return db.ltrim(key, trimSize, listHeadSeq) } diff --git a/ledis/t_list_test.go b/ledis/t_list_test.go index b21118a..32d960d 100644 --- a/ledis/t_list_test.go +++ b/ledis/t_list_test.go @@ -2,6 +2,7 @@ package ledis import ( "fmt" + "strconv" "sync" "testing" "time" @@ -29,6 +30,93 @@ func TestListCodec(t *testing.T) { } } +func TestListTrim(t *testing.T) { + db := getTestDB() + + key := []byte("test_list_trim") + + init := func() { + db.LClear(key) + for i := 0; i < 100; i++ { + n, err := db.RPush(key, []byte(strconv.Itoa(i))) + if err != nil { + t.Fatal(err) + } + if n != int64(i+1) { + t.Fatal("length wrong") + } + } + } + + init() + + err := db.LTrim(key, 0, 99) + if err != nil { + t.Fatal(err) + } + if l, _ := db.LLen(key); l != int64(100) { + t.Fatal("wrong len:", l) + } + + err = db.LTrim(key, 0, 50) + if err != nil { + t.Fatal(err) + } + if l, _ := db.LLen(key); l != int64(51) { + t.Fatal("wrong len:", l) + } + for i := int32(0); i < 51; i++ { + v, err := db.LIndex(key, i) + if err != nil { + t.Fatal(err) + } + if string(v) != strconv.Itoa(int(i)) { + t.Fatal("wrong value") + } + } + + err = db.LTrim(key, 11, 30) + if err != nil { + t.Fatal(err) + } + if l, _ := db.LLen(key); l != int64(30-11+1) { + t.Fatal("wrong len:", l) + } + for i := int32(11); i < 31; i++ { + v, err := db.LIndex(key, i-11) + if err != nil { + t.Fatal(err) + } + if string(v) != strconv.Itoa(int(i)) { + t.Fatal("wrong value") + } + } + + err = db.LTrim(key, 0, -1) + if err != nil { + t.Fatal(err) + } + if l, _ := db.LLen(key); l != int64(30-11+1) { + t.Fatal("wrong len:", l) + } + + init() + err = db.LTrim(key, -3, -3) + if err != nil { + t.Fatal(err) + } + if l, _ := db.LLen(key); l != int64(1) { + t.Fatal("wrong len:", l) + } + v, err := db.LIndex(key, 0) + if err != nil { + t.Fatal(err) + } + if string(v) != "97" { + t.Fatal("wrong value", string(v)) + } +} + func TestDBList(t *testing.T) { db := getTestDB() diff --git a/server/cmd_list.go b/server/cmd_list.go index 0f23094..e7523f3 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -292,6 +292,34 @@ func lkeyexistsCommand(c *client) error { return nil } +func lTrimCommand(c *client) error { + args := c.args + if len(args) != 3 { + return ErrCmdParams + } + + var start int32 + var stop int32 + var err error + + start, err = ledis.StrInt64(args[1], nil) + if err != nil { + return ErrValue + } + stop, err = ledis.StrInt64(args[2], nil) + if err != nil { + return ErrValue + } + + if err := c.db.LTrim(args[0], start, stop); err != nil { + return err + } else { + c.resp.writeStatus(OK) + } + + return nil +} + func lTrimFrontCommand(c *client) error { args := c.args if len(args) != 2 { @@ -361,4 +389,5 @@ func init() { register("ltrim_front", lTrimFrontCommand) register("ltrim_back", lTrimBackCommand) + register("ltrim", lTrimCommand) }