diff --git a/ledis/sort.go b/ledis/sort.go new file mode 100644 index 0000000..cfd8b10 --- /dev/null +++ b/ledis/sort.go @@ -0,0 +1,232 @@ +package ledis + +import ( + "bytes" + "fmt" + "github.com/siddontang/ledisdb/store" + "sort" + "strconv" +) + +type Limit struct { + Offset int + Size int +} + +func getSortRange(values [][]byte, offset int, size int) (int, int) { + var start = 0 + if offset > 0 { + start = offset + } + + valueLen := len(values) + var end = valueLen - 1 + if size > 0 { + end = start + size - 1 + } + + if start >= valueLen { + start = valueLen - 1 + end = valueLen - 2 + } + + if end >= valueLen { + end = valueLen - 1 + } + + return start, end +} + +var hashPattern = []byte("*->") + +func (db *DB) lookupKeyByPattern(pattern []byte, subKey []byte) []byte { + // If the pattern is #, return the substitution key itself + if bytes.Equal(pattern, []byte{'#'}) { + return subKey + } + + // If we can't find '*' in the pattern, return nil + if !bytes.Contains(pattern, []byte{'*'}) { + return nil + } + + key := pattern + var field []byte = nil + + // Find out if we're dealing with a hash dereference + if n := bytes.Index(pattern, hashPattern); n > 0 && n+3 < len(pattern) { + key = pattern[0 : n+1] + field = pattern[n+3:] + } + + // Perform the '*' substitution + key = bytes.Replace(key, []byte{'*'}, subKey, 1) + + var value []byte + if field == nil { + value, _ = db.Get(key) + } else { + value, _ = db.HGet(key, field) + } + + return value +} + +type sortItem struct { + value []byte + cmpValue []byte + score float64 +} + +type sortItemSlice struct { + alpha bool + sortByPattern bool + items []sortItem +} + +func (s *sortItemSlice) Len() int { + return len(s.items) +} + +func (s *sortItemSlice) Swap(i, j int) { + s.items[i], s.items[j] = s.items[j], s.items[i] +} + +func (s *sortItemSlice) Less(i, j int) bool { + s1 := s.items[i] + s2 := s.items[j] + if !s.alpha { + if s1.score < s2.score { + return true + } else if s1.score > s2.score { + return false + } else { + return bytes.Compare(s1.value, s2.value) < 0 + } + } else { + if s.sortByPattern { + if s1.cmpValue == nil || s2.cmpValue == nil { + if s1.cmpValue == nil { + return true + } else { + return false + } + } else { + // Unlike redis, we only use bytes compare + return bytes.Compare(s1.cmpValue, s2.cmpValue) < 0 + } + } else { + // Unlike redis, we only use bytes compare + return bytes.Compare(s1.value, s2.value) < 0 + } + } +} + +func (db *DB) xsort(values [][]byte, offset int, size int, alpha bool, desc bool, sortBy []byte, sortGet [][]byte) ([][]byte, error) { + if len(values) == 0 { + return [][]byte{}, nil + } + + start, end := getSortRange(values, offset, size) + + dontsort := 0 + + if sortBy != nil { + if !bytes.Contains(sortBy, []byte{'*'}) { + dontsort = 1 + } + } + + items := &sortItemSlice{ + alpha: alpha, + sortByPattern: sortBy != nil, + items: make([]sortItem, len(values)), + } + + for i, value := range values { + items.items[i].value = value + items.items[i].score = 0 + items.items[i].cmpValue = nil + + if dontsort == 0 { + var cmpValue []byte + if sortBy != nil { + cmpValue = db.lookupKeyByPattern(sortBy, value) + } else { + // use value iteself to sort by + cmpValue = value + } + + if cmpValue == nil { + continue + } + + if alpha { + if sortBy != nil { + items.items[i].cmpValue = cmpValue + } + } else { + score, err := strconv.ParseFloat(string(cmpValue), 64) + if err != nil { + return nil, fmt.Errorf("%s scores can't be converted into double", cmpValue) + } + items.items[i].score = score + } + } + } + + if dontsort == 0 { + if !desc { + sort.Sort(items) + } else { + sort.Sort(sort.Reverse(items)) + } + } + + var resLen int = end - start + 1 + if len(sortGet) > 0 { + resLen = len(sortGet) * (end - start + 1) + } + + res := make([][]byte, 0, resLen) + for i := start; i <= end; i++ { + if len(sortGet) == 0 { + res = append(res, items.items[i].value) + } else { + for _, getPattern := range sortGet { + v := db.lookupKeyByPattern(getPattern, items.items[i].value) + res = append(res, v) + } + } + } + + return res, nil +} + +func (db *DB) XLSort(key []byte, offset int, size int, alpha bool, desc bool, sortBy []byte, sortGet [][]byte) ([][]byte, error) { + values, err := db.LRange(key, 0, -1) + + if err != nil { + return nil, err + } + + return db.xsort(values, offset, size, alpha, desc, sortBy, sortGet) +} + +func (db *DB) XSSort(key []byte, offset int, size int, alpha bool, desc bool, sortBy []byte, sortGet [][]byte) ([][]byte, error) { + values, err := db.SMembers(key) + if err != nil { + return nil, err + } + + return db.xsort(values, offset, size, alpha, desc, sortBy, sortGet) +} + +func (db *DB) XZSort(key []byte, offset int, size int, alpha bool, desc bool, sortBy []byte, sortGet [][]byte) ([][]byte, error) { + values, err := db.ZRangeByLex(key, nil, nil, store.RangeClose, 0, -1) + if err != nil { + return nil, err + } + + return db.xsort(values, offset, size, alpha, desc, sortBy, sortGet) +} diff --git a/ledis/sort_test.go b/ledis/sort_test.go new file mode 100644 index 0000000..8d59544 --- /dev/null +++ b/ledis/sort_test.go @@ -0,0 +1,123 @@ +package ledis + +import ( + "runtime/debug" + "testing" +) + +func testBuildValues(values ...string) [][]byte { + v := make([][]byte, 0, len(values)) + for _, value := range values { + v = append(v, []byte(value)) + } + return v +} + +func checkSortRes(t *testing.T, res [][]byte, values ...string) { + if len(res) != len(values) { + debug.PrintStack() + t.Fatalf("invalid xsort res len, %d = %d", len(res), len(values)) + } + + for i := 0; i < len(res); i++ { + if string(res[i]) != values[i] { + debug.PrintStack() + t.Fatalf("invalid xsort res at %d, %s != %s", i, res[i], values[i]) + } + } +} + +func checkTestSort(t *testing.T, db *DB, items []string, offset int, size int, alpha bool, + desc bool, sortBy []byte, sortGet [][]byte, checkValues []string) { + + vv := testBuildValues(items...) + + res, err := db.xsort(vv, offset, size, alpha, desc, sortBy, sortGet) + if err != nil { + t.Fatal(err) + } + + checkSortRes(t, res, checkValues...) +} + +func TestSort(t *testing.T) { + db := getTestDB() + + db.FlushAll() + + // Prepare data + db.MSet( + KVPair{[]byte("weight_1"), []byte("30")}, + KVPair{[]byte("weight_2"), []byte("20")}, + KVPair{[]byte("weight_3"), []byte("10")}, + KVPair{[]byte("weight_a"), []byte("60")}, + KVPair{[]byte("weight_b"), []byte("50")}, + KVPair{[]byte("weight_c"), []byte("40")}) + + db.HSet([]byte("hash_weight_1"), []byte("index"), []byte("30")) + db.HSet([]byte("hash_weight_2"), []byte("index"), []byte("20")) + db.HSet([]byte("hash_weight_3"), []byte("index"), []byte("10")) + db.HSet([]byte("hash_weight_a"), []byte("index"), []byte("60")) + db.HSet([]byte("hash_weight_b"), []byte("index"), []byte("50")) + db.HSet([]byte("hash_weight_c"), []byte("index"), []byte("40")) + + db.MSet( + KVPair{[]byte("object_1"), []byte("30")}, + KVPair{[]byte("object_2"), []byte("20")}, + KVPair{[]byte("object_3"), []byte("10")}, + KVPair{[]byte("number_1"), []byte("10")}, + KVPair{[]byte("number_2"), []byte("20")}, + KVPair{[]byte("number_3"), []byte("30")}, + KVPair{[]byte("object_a"), []byte("60")}, + KVPair{[]byte("object_b"), []byte("50")}, + KVPair{[]byte("object_c"), []byte("40")}, + KVPair{[]byte("number_a"), []byte("40")}, + KVPair{[]byte("number_b"), []byte("50")}, + KVPair{[]byte("number_c"), []byte("60")}) + + db.HSet([]byte("hash_object_1"), []byte("index"), []byte("30")) + db.HSet([]byte("hash_object_2"), []byte("index"), []byte("20")) + db.HSet([]byte("hash_object_3"), []byte("index"), []byte("10")) + db.HSet([]byte("hash_number_1"), []byte("index"), []byte("10")) + db.HSet([]byte("hash_number_2"), []byte("index"), []byte("20")) + db.HSet([]byte("hash_number_3"), []byte("index"), []byte("30")) + + db.HSet([]byte("hash_object_a"), []byte("index"), []byte("60")) + db.HSet([]byte("hash_object_b"), []byte("index"), []byte("50")) + db.HSet([]byte("hash_object_c"), []byte("index"), []byte("40")) + db.HSet([]byte("hash_number_a"), []byte("index"), []byte("40")) + db.HSet([]byte("hash_number_b"), []byte("index"), []byte("50")) + db.HSet([]byte("hash_number_c"), []byte("index"), []byte("60")) + + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, nil, []string{"1", "2", "3"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, 1, false, false, nil, nil, []string{"1"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, 2, false, false, nil, nil, []string{"1", "2"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, 3, false, false, nil, nil, []string{"1", "2", "3"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, 4, false, false, nil, nil, []string{"1", "2", "3"}) + + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, true, false, nil, nil, []string{"1", "2", "3"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, true, nil, nil, []string{"3", "2", "1"}) + + if _, err := db.xsort(testBuildValues("c", "b", "a"), 0, -1, false, false, nil, nil); err == nil { + t.Fatal("must nil") + } + + checkTestSort(t, db, []string{"c", "b", "a"}, 0, -1, true, false, nil, nil, []string{"a", "b", "c"}) + checkTestSort(t, db, []string{"c", "b", "a"}, 0, -1, true, true, nil, nil, []string{"c", "b", "a"}) + + checkTestSort(t, db, []string{"1", "2", "3"}, 0, 1, false, false, nil, nil, []string{"1"}) + checkTestSort(t, db, []string{"1", "2", "3"}, 0, 1, false, true, nil, nil, []string{"3"}) + + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, []byte("abc"), nil, []string{"3", "2", "1"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, []byte("weight_*"), nil, []string{"3", "2", "1"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, true, []byte("weight_*"), nil, []string{"1", "2", "3"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, []byte("hash_weight_*->index"), nil, []string{"3", "2", "1"}) + + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, [][]byte{[]byte("object_*")}, []string{"30", "20", "10"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, [][]byte{[]byte("number_*")}, []string{"10", "20", "30"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, [][]byte{[]byte("#"), []byte("number_*")}, + []string{"1", "10", "2", "20", "3", "30"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, [][]byte{[]byte("object_*"), []byte("number_*")}, + []string{"30", "10", "20", "20", "10", "30"}) + checkTestSort(t, db, []string{"3", "2", "1"}, 0, -1, false, false, nil, [][]byte{[]byte("object_*_abc")}, []string{"", "", ""}) +} diff --git a/server/cmd_sort.go b/server/cmd_sort.go new file mode 100644 index 0000000..5a9cf19 --- /dev/null +++ b/server/cmd_sort.go @@ -0,0 +1,118 @@ +package server + +import ( + "bytes" + "fmt" + "strconv" + "strings" +) + +func xsort(c *client, tp string, key []byte, offset int, size int, alpha bool, + desc bool, sortBy []byte, sortGet [][]byte) ([][]byte, error) { + var ay [][]byte + var err error + switch strings.ToUpper(tp) { + case "LIST": + ay, err = c.db.XLSort(key, offset, size, alpha, desc, sortBy, sortGet) + case "SET": + ay, err = c.db.XSSort(key, offset, size, alpha, desc, sortBy, sortGet) + case "ZSET": + ay, err = c.db.XZSort(key, offset, size, alpha, desc, sortBy, sortGet) + default: + err = fmt.Errorf("invalid key type %s", tp) + } + return ay, err +} + +func xlsortCommand(c *client) error { + return handleXSort(c, "LIST") +} +func xssortCommand(c *client) error { + return handleXSort(c, "SET") +} +func xzsortCommand(c *client) error { + return handleXSort(c, "ZSET") +} + +var ascArg = []byte("asc") +var descArg = []byte("desc") +var alphaArg = []byte("alpha") +var limitArg = []byte("limit") +var storeArg = []byte("store") +var byArg = []byte("by") +var getArg = []byte("get") + +func handleXSort(c *client, tp string) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + key := args[0] + desc := false + alpha := false + offset := 0 + size := 0 + var storeKey []byte + var sortBy []byte + var sortGet [][]byte + var err error + + for i := 1; i < len(args); { + if bytes.EqualFold(args[i], ascArg) { + desc = false + } else if bytes.EqualFold(args[i], descArg) { + desc = true + } else if bytes.EqualFold(args[i], alphaArg) { + alpha = true + } else if bytes.EqualFold(args[i], limitArg) && i+2 < len(args) { + if offset, err = strconv.Atoi(string(args[i+1])); err != nil { + return err + } + if size, err = strconv.Atoi(string(args[i+2])); err != nil { + return err + } + i = i + 2 + } else if bytes.EqualFold(args[i], storeArg) && i+1 < len(args) { + storeKey = args[i+1] + i++ + } else if bytes.EqualFold(args[i], byArg) && i+1 < len(args) { + sortBy = args[i+1] + i++ + } else if bytes.EqualFold(args[i], getArg) && i+1 < len(args) { + sortGet = append(sortGet, args[i+1]) + i++ + } else { + return ErrCmdParams + } + + i++ + } + + ay, err := xsort(c, tp, key, offset, size, alpha, desc, sortBy, sortGet) + if err != nil { + return err + } + + if storeKey == nil { + c.resp.writeSliceArray(ay) + } else { + // not threadsafe now, need lock??? + if _, err = c.db.LClear(storeKey); err != nil { + return err + } + + if n, err := c.db.RPush(storeKey, ay...); err != nil { + return err + } else { + c.resp.writeInteger(n) + } + } + return nil +} + +func init() { + register("xlsort", xlsortCommand) + register("xssort", xssortCommand) + register("xzsort", xzsortCommand) +} diff --git a/server/cmd_sort_test.go b/server/cmd_sort_test.go new file mode 100644 index 0000000..9812611 --- /dev/null +++ b/server/cmd_sort_test.go @@ -0,0 +1,91 @@ +package server + +import ( + "fmt" + goledis "github.com/siddontang/ledisdb/client/go/ledis" + "testing" +) + +func checkTestSortRes(ay interface{}, checks []string) error { + values, ok := ay.([]interface{}) + if !ok { + return fmt.Errorf("invalid res type %T, must [][]byte", ay) + } + + if len(values) != len(checks) { + return fmt.Errorf("invalid res number %d != %d", len(values), len(checks)) + } + + for i, _ := range values { + if string(values[i].([]byte)) != checks[i] { + return fmt.Errorf("invalid res at %d, %s != %s", i, values[i], checks[i]) + } + } + return nil +} + +func TestSort(t *testing.T) { + c := getTestConn() + defer c.Close() + + key := "my_sort_key" + storeKey := "my_sort_store_key" + + if _, err := c.Do("LPUSH", key, 3, 2, 1); err != nil { + t.Fatal(err) + } + + if _, err := c.Do("MSET", "weight_1", 3, "weight_2", 2, "weight_3", 1); err != nil { + t.Fatal(err) + } + + if _, err := c.Do("MSET", "object_1", 10, "object_2", 20, "object_3", 30); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"1", "2", "3"}); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key, "DESC"); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"3", "2", "1"}); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key, "LIMIT", 0, 1); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"1"}); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key, "BY", "weight_*"); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"3", "2", "1"}); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key, "GET", "object_*"); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"10", "20", "30"}); err != nil { + t.Fatal(err) + } + + if ay, err := c.Do("XLSORT", key, "GET", "object_*", "GET", "#"); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"10", "1", "20", "2", "30", "3"}); err != nil { + t.Fatal(err) + } + + if n, err := goledis.Int(c.Do("XLSORT", key, "STORE", storeKey)); err != nil { + t.Fatal(err) + } else if n != 3 { + t.Fatalf("invalid return store sort number, %d != 3", n) + } else if ay, err := c.Do("LRANGE", storeKey, 0, -1); err != nil { + t.Fatal(err) + } else if err = checkTestSortRes(ay, []string{"1", "2", "3"}); err != nil { + t.Fatal(err) + } +}