diff --git a/server/client_resp.go b/server/client_resp.go index 6d493f9..475445a 100644 --- a/server/client_resp.go +++ b/server/client_resp.go @@ -134,6 +134,11 @@ func (c *respClient) run() { reqData, err := c.readRequest() if err == nil { err = c.handleRequest(reqData) + + c.cmd = "" + c.args = nil + + c.ar.Reset() } if err != nil { @@ -154,6 +159,16 @@ func (c *respClient) handleRequest(reqData [][]byte) error { c.cmd = hack.String(lowerSlice(reqData[0])) c.args = reqData[1:] } + + if c.cmd == "xuse" { + err := c.handleUseThenCmd() + if err != nil { + c.resp.writeError(err) + c.resp.flush() + return nil + } + } + if c.cmd == "quit" { c.activeQuit = true c.resp.writeStatus(OK) @@ -164,10 +179,35 @@ func (c *respClient) handleRequest(reqData [][]byte) error { c.perform() - c.cmd = "" - c.args = nil + return nil +} - c.ar.Reset() +// XUSE db THEN command +func (c *respClient) handleUseThenCmd() error { + if len(c.args) <= 2 { + // invalid command format + return fmt.Errorf("invalid format for XUSE, must XUSE db THEN your command") + } + + if hack.String(upperSlice(c.args[1])) != "THEN" { + // invalid command format, just resturn here + return fmt.Errorf("invalid format for XUSE, must XUSE db THEN your command") + } + + index, err := strconv.Atoi(hack.String(c.args[0])) + if err != nil { + return fmt.Errorf("invalid db for XUSE, err %v", err) + } + + db, err := c.app.ldb.Select(index) + if err != nil { + return fmt.Errorf("invalid db for XUSE, err %v", err) + } + + c.db = db + + c.cmd = hack.String(lowerSlice(c.args[2])) + c.args = c.args[3:] return nil } diff --git a/server/cmd_server_test.go b/server/cmd_server_test.go new file mode 100644 index 0000000..636b90e --- /dev/null +++ b/server/cmd_server_test.go @@ -0,0 +1,38 @@ +package server + +import ( + "github.com/siddontang/goredis" + "testing" +) + +func TestXuse(t *testing.T) { + c1 := getTestConn() + defer c1.Close() + + c2 := getTestConn() + defer c2.Close() + + _, err := c1.Do("XUSE", "1", "THEN", "SET", "tmp_select_key", "1") + if err != nil { + t.Fatal(err) + } + + _, err = goredis.Int(c2.Do("GET", "tmp_select_key")) + if err != goredis.ErrNil { + t.Fatal(err) + } + + n, _ := goredis.Int(c2.Do("XUSE", "1", "THEN", "GET", "tmp_select_key")) + if n != 1 { + t.Fatal(n) + } + + n, _ = goredis.Int(c2.Do("GET", "tmp_select_key")) + if n != 1 { + t.Fatal(n) + } + + c1.Do("SELECT", 0) + c2.Do("SELECT", 0) + +} diff --git a/server/util.go b/server/util.go index 7c5b73d..a6ef5d4 100644 --- a/server/util.go +++ b/server/util.go @@ -134,3 +134,14 @@ func lowerSlice(buf []byte) []byte { } return buf } + +func upperSlice(buf []byte) []byte { + for i, r := range buf { + if 'a' <= r && r <= 'z' { + r -= 'a' - 'A' + } + + buf[i] = r + } + return buf +}