diff --git a/internal/server/crud.go b/internal/server/crud.go index b25a594d..8b0f0d2f 100644 --- a/internal/server/crud.go +++ b/internal/server/crud.go @@ -454,6 +454,62 @@ func (server *Server) cmdDrop(msg *Message) (res resp.Value, d commandDetails, e return } +func (server *Server) cmdRename(msg *Message, nx bool) (res resp.Value, d commandDetails, err error) { + start := time.Now() + vs := msg.Args[1:] + var ok bool + if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { + err = errInvalidNumberOfArguments + return + } + if vs, d.newKey, ok = tokenval(vs); !ok || d.newKey == "" { + err = errInvalidNumberOfArguments + return + } + if len(vs) != 0 { + err = errInvalidNumberOfArguments + return + } + for _, h := range server.hooks { + if h.Key == d.key || h.Key == d.newKey { + err = errKeyHasHooksSet + return + } + } + col := server.getCol(d.key) + if col == nil { + err = errKeyNotFound + return + } + d.command = "rename" + d.updated = true + newCol := server.getCol(d.newKey) + if newCol != nil { + if nx { + d.updated = false + } else { + server.deleteCol(d.newKey) + } + } + if d.updated { + server.deleteCol(d.key) + server.setCol(d.newKey, col) + server.moveKeyExpires(d.key, d.newKey) + } + d.timestamp = time.Now() + switch msg.OutputType { + case JSON: + res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}") + case RESP: + if d.updated { + res = resp.IntegerValue(1) + } else { + res = resp.IntegerValue(0) + } + } + return +} + func (server *Server) cmdFlushDB(msg *Message) (res resp.Value, d commandDetails, err error) { start := time.Now() vs := msg.Args[1:] diff --git a/internal/server/expire.go b/internal/server/expire.go index 823106cc..c7f5900b 100644 --- a/internal/server/expire.go +++ b/internal/server/expire.go @@ -64,6 +64,13 @@ func (c *Server) clearKeyExpires(key string) { delete(c.expires, key) } +// moveKeyExpires moves all items that are marked as expires from a key to a newKey. +func (c *Server) moveKeyExpires(key, newKey string) { + val := c.expires[key] + delete(c.expires, key) + c.expires[newKey] = val +} + // expireAt marks an item as expires at a specific time. func (c *Server) expireAt(key, id string, at time.Time) { m := c.expires[key] diff --git a/internal/server/scripts.go b/internal/server/scripts.go index 13832fbd..587222c6 100644 --- a/internal/server/scripts.go +++ b/internal/server/scripts.go @@ -575,6 +575,10 @@ func (c *Server) commandInScript(msg *Message) ( res, d, err = c.cmdDrop(msg) case "expire": res, d, err = c.cmdExpire(msg) + case "rename": + res, d, err = c.cmdRename(msg, false) + case "renamenx": + res, d, err = c.cmdRename(msg, true) case "persist": res, d, err = c.cmdPersist(msg) case "ttl": @@ -642,7 +646,8 @@ func (c *Server) luaTile38AtomicRW(msg *Message) (resp.Value, error) { switch msg.Command() { default: return resp.NullValue(), errCmdNotSupported - case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel": + case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel", + "rename", "renamenx": // write operations write = true if c.config.followHost() != "" { @@ -678,7 +683,9 @@ func (c *Server) luaTile38AtomicRO(msg *Message) (resp.Value, error) { default: return resp.NullValue(), errCmdNotSupported - case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel": + case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel", + "rename", "renamenx": + // write operations return resp.NullValue(), errReadOnly case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search", @@ -704,7 +711,8 @@ func (c *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) { switch msg.Command() { default: return resp.NullValue(), errCmdNotSupported - case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel": + case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel", + "rename", "renamenx": // write operations write = true c.mu.Lock() diff --git a/internal/server/server.go b/internal/server/server.go index 75b37106..4fb9056c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -49,6 +49,7 @@ const ( type commandDetails struct { command string // client command, like "SET" or "DEL" key, id string // collection key and object id of object + newKey string // new key, for RENAME command fmap map[string]int // map of field names to value indexes obj geojson.Object // new object fields []float64 // array of field values @@ -934,7 +935,7 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { case "set", "del", "drop", "fset", "flushdb", "setchan", "pdelchan", "delchan", "sethook", "pdelhook", "delhook", - "expire", "persist", "jset", "pdel": + "expire", "persist", "jset", "pdel", "rename", "renamenx": // write operations write = true server.mu.Lock() @@ -1072,6 +1073,10 @@ func (server *Server) command(msg *Message, client *Client) ( res, d, err = server.cmdDrop(msg) case "flushdb": res, d, err = server.cmdFlushDB(msg) + case "rename": + res, d, err = server.cmdRename(msg, false) + case "renamenx": + res, d, err = server.cmdRename(msg, true) case "sethook": res, d, err = server.cmdSetHook(msg, false) diff --git a/internal/server/token.go b/internal/server/token.go index 05721ec2..71a81044 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -17,6 +17,7 @@ var errKeyNotFound = errors.New("key not found") var errIDNotFound = errors.New("id not found") var errIDAlreadyExists = errors.New("id already exists") var errPathNotFound = errors.New("path not found") +var errKeyHasHooksSet = errors.New("key has hooks set") func errInvalidArgument(arg string) error { return fmt.Errorf("invalid argument '%s'", arg) diff --git a/tests/keys_test.go b/tests/keys_test.go index a02ed170..636aa31a 100644 --- a/tests/keys_test.go +++ b/tests/keys_test.go @@ -18,6 +18,8 @@ func subTestKeys(t *testing.T, mc *mockServer) { runStep(t, mc, "BOUNDS", keys_BOUNDS_test) runStep(t, mc, "DEL", keys_DEL_test) runStep(t, mc, "DROP", keys_DROP_test) + runStep(t, mc, "RENAME", keys_RENAME_test) + runStep(t, mc, "RENAMENX", keys_RENAMENX_test) runStep(t, mc, "EXPIRE", keys_EXPIRE_test) runStep(t, mc, "FSET", keys_FSET_test) runStep(t, mc, "GET", keys_GET_test) @@ -65,6 +67,40 @@ func keys_DROP_test(mc *mockServer) error { {"SCAN", "mykey", "COUNT"}, {0}, }) } +func keys_RENAME_test(mc *mockServer) error { + return mc.DoBatch([][]interface{}{ + {"SET", "mykey", "myid1", "HASH", "9my5xp7"}, {"OK"}, + {"SET", "mykey", "myid2", "HASH", "9my5xp8"}, {"OK"}, + {"SCAN", "mykey", "COUNT"}, {2}, + {"RENAME", "mykey", "mynewkey"}, {1}, + {"SCAN", "mykey", "COUNT"}, {0}, + {"SCAN", "mynewkey", "COUNT"}, {2}, + {"SET", "mykey", "myid3", "HASH", "9my5xp7"}, {"OK"}, + {"RENAME", "mykey", "mynewkey"}, {1}, + {"SCAN", "mykey", "COUNT"}, {0}, + {"SCAN", "mynewkey", "COUNT"}, {1}, + {"RENAME", "foo", "mynewkey"}, {"ERR key not found"}, + {"SCAN", "mynewkey", "COUNT"}, {1}, + }) +} +func keys_RENAMENX_test(mc *mockServer) error { + return mc.DoBatch([][]interface{}{ + {"SET", "mykey", "myid1", "HASH", "9my5xp7"}, {"OK"}, + {"SET", "mykey", "myid2", "HASH", "9my5xp8"}, {"OK"}, + {"SCAN", "mykey", "COUNT"}, {2}, + {"RENAMENX", "mykey", "mynewkey"}, {1}, + {"SCAN", "mykey", "COUNT"}, {0}, + {"DROP", "mykey"}, {0}, + {"SCAN", "mykey", "COUNT"}, {0}, + {"SCAN", "mynewkey", "COUNT"}, {2}, + {"SET", "mykey", "myid3", "HASH", "9my5xp7"}, {"OK"}, + {"RENAMENX", "mykey", "mynewkey"}, {0}, + {"SCAN", "mykey", "COUNT"}, {1}, + {"SCAN", "mynewkey", "COUNT"}, {2}, + {"RENAMENX", "foo", "mynewkey"}, {"ERR key not found"}, + {"SCAN", "mynewkey", "COUNT"}, {2}, + }) +} func keys_EXPIRE_test(mc *mockServer) error { return mc.DoBatch([][]interface{}{ {"SET", "mykey", "myid", "STRING", "value"}, {"OK"},