diff --git a/ledis/t_list.go b/ledis/t_list.go index 021da7e..faf6d39 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -9,6 +9,7 @@ import ( "github.com/siddontang/go/hack" "github.com/siddontang/go/log" + "github.com/siddontang/go/num" "github.com/siddontang/ledisdb/store" ) @@ -215,6 +216,61 @@ func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { return value, err } +func (db *DB) ltrim(key []byte, trimSize, whereSeq int32) (int32, error) { + if err := checkKeySize(key); err != nil { + return 0, err + } + + if trimSize == 0 { + return 0, nil + } + + t := db.listBatch + t.Lock() + defer t.Unlock() + + var headSeq int32 + var tailSeq int32 + var size int32 + var err error + + metaKey := db.lEncodeMetaKey(key) + headSeq, tailSeq, size, err = db.lGetMeta(nil, metaKey) + if err != nil { + return 0, err + } else if size == 0 { + return 0, nil + } + + var ( + trimStartSeq int32 + trimEndSeq int32 + ) + + if whereSeq == listHeadSeq { + trimStartSeq = headSeq + trimEndSeq = num.MinInt32(trimStartSeq+trimSize-1, tailSeq) + headSeq = trimEndSeq + 1 + } else { + trimEndSeq = tailSeq + trimStartSeq = num.MaxInt32(trimEndSeq-trimSize+1, headSeq) + tailSeq = trimStartSeq - 1 + } + + for trimSeq := trimStartSeq; trimSeq <= trimEndSeq; trimSeq++ { + itemKey := db.lEncodeListKey(key, trimSeq) + t.Delete(itemKey) + } + + size = db.lSetMeta(metaKey, headSeq, tailSeq) + if size == 0 { + db.rmExpire(t, ListType, key) + } + + err = t.Commit() + return trimEndSeq - trimStartSeq + 1, err +} + // ps : here just focus on deleting the list data, // any other likes expire is ignore. func (db *DB) lDelete(t *batch, key []byte) int64 { @@ -352,6 +408,14 @@ func (db *DB) LPop(key []byte) ([]byte, error) { return db.lpop(key, listHeadSeq) } +func (db *DB) LTrimFront(key []byte, trimSize int32) (int32, error) { + return db.ltrim(key, trimSize, listHeadSeq) +} + +func (db *DB) LTrimBack(key []byte, trimSize int32) (int32, error) { + return db.ltrim(key, trimSize, listTailSeq) +} + func (db *DB) LPush(key []byte, args ...[]byte) (int64, error) { return db.lpush(key, listHeadSeq, args...) } diff --git a/server/cmd_list.go b/server/cmd_list.go index 4934621..0f23094 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -292,6 +292,52 @@ func lkeyexistsCommand(c *client) error { return nil } +func lTrimFrontCommand(c *client) error { + args := c.args + if len(args) != 2 { + return ErrCmdParams + } + + var trimSize int32 + var err error + + trimSize, err = ledis.StrInt32(args[1], nil) + if err != nil || trimSize < 0 { + return ErrValue + } + + if n, err := c.db.LTrimFront(args[0], trimSize); err != nil { + return err + } else { + c.resp.writeInteger(int64(n)) + } + + return nil +} + +func lTrimBackCommand(c *client) error { + args := c.args + if len(args) != 2 { + return ErrCmdParams + } + + var trimSize int32 + var err error + + trimSize, err = ledis.StrInt32(args[1], nil) + if err != nil || trimSize < 0 { + return ErrValue + } + + if n, err := c.db.LTrimBack(args[0], trimSize); err != nil { + return err + } else { + c.resp.writeInteger(int64(n)) + } + + return nil +} + func init() { register("blpop", blpopCommand) register("brpop", brpopCommand) @@ -312,4 +358,7 @@ func init() { register("lttl", lttlCommand) register("lpersist", lpersistCommand) register("lkeyexists", lkeyexistsCommand) + + register("ltrim_front", lTrimFrontCommand) + register("ltrim_back", lTrimBackCommand) } diff --git a/server/cmd_list_test.go b/server/cmd_list_test.go index 3386d43..a96ec54 100644 --- a/server/cmd_list_test.go +++ b/server/cmd_list_test.go @@ -297,6 +297,72 @@ func TestPop(t *testing.T) { } +func TestTrim(t *testing.T) { + c := getTestConn() + defer c.Close() + + key := []byte("d") + if n, err := goredis.Int(c.Do("rpush", key, 1, 2, 3, 4, 5, 6)); err != nil { + t.Fatal(err) + } else if n != 6 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("ltrim_front", key, 2)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 4 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("ltrim_back", key, 2)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("ltrim_front", key, 5)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 0 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("rpush", key, 1, 2)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("ltrim_front", key, 2)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := goredis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 0 { + t.Fatal(n) + } +} + func TestListErrorParams(t *testing.T) { c := getTestConn() defer c.Close() @@ -353,4 +419,11 @@ func TestListErrorParams(t *testing.T) { t.Fatal("invalid err of %v", err) } + if _, err := c.Do("ltrim_front", "test_ltrimfront", "-1"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("ltrim_back", "test_ltrimback", "a"); err == nil { + t.Fatal("invalid err of %v", err) + } }