Merge branch 'develop'

siddontang 2014-06-20 10:13:52 +08:00
commit 8e8f3704e6
74 changed files with 10504 additions and 81 deletions


@ -2,6 +2,4 @@
. ./dev.sh
go get -u github.com/siddontang/go-leveldb/leveldb #nothing to do now
go get -u github.com/siddontang/go-log/log
go get -u github.com/garyburd/redigo/redis


@ -0,0 +1,45 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"strings"
)
const (
watchState = 1 << iota
multiState
subscribeState
monitorState
)
type commandInfo struct {
set, clear int
}
var commandInfos = map[string]commandInfo{
"WATCH": commandInfo{set: watchState},
"UNWATCH": commandInfo{clear: watchState},
"MULTI": commandInfo{set: multiState},
"EXEC": commandInfo{clear: watchState | multiState},
"DISCARD": commandInfo{clear: watchState | multiState},
"PSUBSCRIBE": commandInfo{set: subscribeState},
"SUBSCRIBE": commandInfo{set: subscribeState},
"MONITOR": commandInfo{set: monitorState},
}
func lookupCommandInfo(commandName string) commandInfo {
return commandInfos[strings.ToUpper(commandName)]
}
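
The set and clear bitmasks above are consumed by the pooled connection in pool.go (see pooledConnection.Do and Send later in this diff). A minimal sketch of that bookkeeping, reusing the same expression (illustrative only):

    state := 0
    ci := lookupCommandInfo("MULTI")
    state = (state | ci.set) &^ ci.clear // multiState is now set
    ci = lookupCommandInfo("exec")       // lookup is case-insensitive
    state = (state | ci.set) &^ ci.clear // EXEC clears multiState and watchState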

418
client/go/redis/conn.go Normal file

@ -0,0 +1,418 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
)
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
pending int
err error
conn net.Conn
// Read
readTimeout time.Duration
br *bufio.Reader
// Write
writeTimeout time.Duration
bw *bufio.Writer
// Scratch space for formatting argument length.
// '*' or '$', length, "\r\n"
lenScratch [32]byte
// Scratch space for formatting integers and floats.
numScratch [40]byte
}
// Dial connects to the Redis server at the given network and address.
func Dial(network, address string) (Conn, error) {
c, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewConn(c, 0, 0), nil
}
// DialTimeout acts like Dial but takes timeouts for establishing the
// connection to the server, writing a command and reading a reply.
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
var c net.Conn
var err error
if connectTimeout > 0 {
c, err = net.DialTimeout(network, address, connectTimeout)
} else {
c, err = net.Dial(network, address)
}
if err != nil {
return nil, err
}
return NewConn(c, readTimeout, writeTimeout), nil
}
// NewConn returns a new Redigo connection for the given net connection.
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
return &conn{
conn: netConn,
bw: bufio.NewWriter(netConn),
br: bufio.NewReader(netConn),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = errors.New("redigo: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = err
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) writeLen(prefix byte, n int) error {
c.lenScratch[len(c.lenScratch)-1] = '\n'
c.lenScratch[len(c.lenScratch)-2] = '\r'
i := len(c.lenScratch) - 3
for {
c.lenScratch[i] = byte('0' + n%10)
i -= 1
n = n / 10
if n == 0 {
break
}
}
c.lenScratch[i] = prefix
_, err := c.bw.Write(c.lenScratch[i:])
return err
}
func (c *conn) writeString(s string) error {
c.writeLen('$', len(s))
c.bw.WriteString(s)
_, err := c.bw.WriteString("\r\n")
return err
}
func (c *conn) writeBytes(p []byte) error {
c.writeLen('$', len(p))
c.bw.Write(p)
_, err := c.bw.WriteString("\r\n")
return err
}
func (c *conn) writeInt64(n int64) error {
return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
}
func (c *conn) writeFloat64(n float64) error {
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
}
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
c.writeLen('*', 1+len(args))
err = c.writeString(cmd)
for _, arg := range args {
if err != nil {
break
}
switch arg := arg.(type) {
case string:
err = c.writeString(arg)
case []byte:
err = c.writeBytes(arg)
case int:
err = c.writeInt64(int64(arg))
case int64:
err = c.writeInt64(arg)
case float64:
err = c.writeFloat64(arg)
case bool:
if arg {
err = c.writeString("1")
} else {
err = c.writeString("0")
}
case nil:
err = c.writeString("")
default:
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
err = c.writeBytes(buf.Bytes())
}
}
return err
}
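// Illustration (a sketch, values chosen arbitrarily): writeCommand emits the
// Redis unified request protocol, an array header followed by one bulk string
// per argument. For example,
//
//     writeCommand("SET", []interface{}{"mykey", 1})
//
// buffers the bytes
//
//     *3\r\n $3\r\nSET\r\n $5\r\nmykey\r\n $1\r\n1\r\n
//
// (spaces added for readability), with the int argument formatted by
// writeInt64 via strconv.AppendInt.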
func (c *conn) readLine() ([]byte, error) {
p, err := c.br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
return nil, errors.New("redigo: long response line")
}
if err != nil {
return nil, err
}
i := len(p) - 2
if i < 0 || p[i] != '\r' {
return nil, errors.New("redigo: bad response line terminator")
}
return p[:i], nil
}
// parseLen parses bulk string and array lengths.
func parseLen(p []byte) (int, error) {
if len(p) == 0 {
return -1, errors.New("redigo: malformed length")
}
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
// handle $-1 and *-1 null replies.
return -1, nil
}
var n int
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return -1, errors.New("redigo: illegal bytes in length")
}
n += int(b - '0')
}
return n, nil
}
// parseInt parses an integer reply.
func parseInt(p []byte) (interface{}, error) {
if len(p) == 0 {
return 0, errors.New("redigo: malformed integer")
}
var negate bool
if p[0] == '-' {
negate = true
p = p[1:]
if len(p) == 0 {
return 0, errors.New("redigo: malformed integer")
}
}
var n int64
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return 0, errors.New("redigo: illegal bytes in length")
}
n += int64(b - '0')
}
if negate {
n = -n
}
return n, nil
}
var (
okReply interface{} = "OK"
pongReply interface{} = "PONG"
)
func (c *conn) readReply() (interface{}, error) {
line, err := c.readLine()
if err != nil {
return nil, err
}
if len(line) == 0 {
return nil, errors.New("redigo: short response line")
}
switch line[0] {
case '+':
switch {
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
// Avoid allocation for frequent "+OK" response.
return okReply, nil
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
// Avoid allocation in PING command benchmarks :)
return pongReply, nil
default:
return string(line[1:]), nil
}
case '-':
return Error(string(line[1:])), nil
case ':':
return parseInt(line[1:])
case '$':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(c.br, p)
if err != nil {
return nil, err
}
if line, err := c.readLine(); err != nil {
return nil, err
} else if len(line) != 0 {
return nil, errors.New("redigo: bad bulk string format")
}
return p, nil
case '*':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
r := make([]interface{}, n)
for i := range r {
r[i], err = c.readReply()
if err != nil {
return nil, err
}
}
return r, nil
}
return nil, errors.New("redigo: unexpected response line")
}
func (c *conn) Send(cmd string, args ...interface{}) error {
c.mu.Lock()
c.pending += 1
c.mu.Unlock()
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err := c.writeCommand(cmd, args); err != nil {
return c.fatal(err)
}
return nil
}
func (c *conn) Flush() error {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err := c.bw.Flush(); err != nil {
return c.fatal(err)
}
return nil
}
func (c *conn) Receive() (reply interface{}, err error) {
c.mu.Lock()
// There can be more receives than sends when using pub/sub. To allow
// normal use of the connection after unsubscribe from all channels, do not
// decrement pending to a negative value.
if c.pending > 0 {
c.pending -= 1
}
c.mu.Unlock()
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err)
}
if err, ok := reply.(Error); ok {
return nil, err
}
return
}
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
c.pending = 0
c.mu.Unlock()
if cmd == "" && pending == 0 {
return nil, nil
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if cmd != "" {
c.writeCommand(cmd, args)
}
if err := c.bw.Flush(); err != nil {
return nil, c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if cmd == "" {
reply := make([]interface{}, pending)
for i := range reply {
r, e := c.readReply()
if e != nil {
return nil, c.fatal(e)
}
reply[i] = r
}
return reply, nil
}
var err error
var reply interface{}
for i := 0; i <= pending; i++ {
var e error
if reply, e = c.readReply(); e != nil {
return nil, c.fatal(e)
}
if e, ok := reply.(Error); ok && err == nil {
err = e
}
}
return reply, err
}
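
A minimal usage sketch for the connection above (assuming the package is imported as redis and a server is listening on 127.0.0.1:6379; error handling abbreviated):

    c, err := redis.Dial("tcp", "127.0.0.1:6379")
    if err != nil {
        // handle the dial error
    }
    defer c.Close()

    // Do writes the command, flushes the buffer and reads one reply.
    if _, err := c.Do("SET", "hello", "world"); err != nil {
        // handle the command error
    }
    v, err := c.Do("GET", "hello") // v is []byte("world") on success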

167
client/go/redis/doc.go Normal file

@ -0,0 +1,167 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redis is a client for the Redis database.
//
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
// documentation about this package.
//
// Connections
//
// The Conn interface is the primary interface for working with Redis.
// Applications create connections by calling the Dial, DialTimeout or
// NewConn functions. In the future, functions will be added for creating
// sharded and other types of connections.
//
// The application must call the connection Close method when the application
// is done with the connection.
//
// Executing Commands
//
// The Conn interface has a generic method for executing Redis commands:
//
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
//
// The Redis command reference (http://redis.io/commands) lists the available
// commands. An example of using the Redis APPEND command is:
//
// n, err := conn.Do("APPEND", "key", "value")
//
// The Do method converts command arguments to binary strings for transmission
// to the server as follows:
//
// Go Type          Conversion
// []byte           Sent as is
// string           Sent as is
// int, int64       strconv.FormatInt(v, 10)
// float64          strconv.FormatFloat(v, 'g', -1, 64)
// bool             true -> "1", false -> "0"
// nil              ""
// all other types  fmt.Print(v)
//
// Redis command reply types are represented using the following Go types:
//
// Redis type     Go type
// error          redis.Error
// integer        int64
// simple string  string
// bulk string    []byte or nil if value not present.
// array          []interface{} or nil if value not present.
//
// Use type assertions or the reply helper functions to convert from
// interface{} to the specific Go type for the command result.
//
// Pipelining
//
// Connections support pipelining using the Send, Flush and Receive methods.
//
// Send(commandName string, args ...interface{}) error
// Flush() error
// Receive() (reply interface{}, err error)
//
// Send writes the command to the connection's output buffer. Flush flushes the
// connection's output buffer to the server. Receive reads a single reply from
// the server. The following example shows a simple pipeline.
//
// c.Send("SET", "foo", "bar")
// c.Send("GET", "foo")
// c.Flush()
// c.Receive() // reply from SET
// v, err = c.Receive() // reply from GET
//
// The Do method combines the functionality of the Send, Flush and Receive
// methods. The Do method starts by writing the command and flushing the output
// buffer. Next, the Do method receives all pending replies including the reply
// for the command just sent by Do. If any of the received replies is an error,
// then Do returns the error. If there are no errors, then Do returns the last
// reply. If the command argument to the Do method is "", then the Do method
// will flush the output buffer and receive pending replies without sending a
// command.
//
// Use the Send and Do methods to implement pipelined transactions.
//
// c.Send("MULTI")
// c.Send("INCR", "foo")
// c.Send("INCR", "bar")
// r, err := c.Do("EXEC")
// fmt.Println(r) // prints [1, 1]
//
// Concurrency
//
// Connections support a single concurrent caller to the write methods (Send,
// Flush) and a single concurrent caller to the read method (Receive). Because
// Do method combines the functionality of Send, Flush and Receive, the Do
// method cannot be called concurrently with the other methods.
//
// For full concurrent access to Redis, use the thread-safe Pool to get and
// release connections from within a goroutine.
//
// Publish and Subscribe
//
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
//
// c.Send("SUBSCRIBE", "example")
// c.Flush()
// for {
// reply, err := c.Receive()
// if err != nil {
// return err
// }
// // process pushed message
// }
//
// The PubSubConn type wraps a Conn with convenience methods for implementing
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
// send and flush a subscription management command. The Receive method
// converts a pushed message to convenient types for use in a type switch.
//
// psc := PubSubConn{c}
// psc.Subscribe("example")
// for {
// switch v := psc.Receive().(type) {
// case redis.Message:
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
// case redis.Subscription:
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
// case error:
// return v
// }
// }
//
// Reply Helpers
//
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
// to a value of a specific type. To allow convenient wrapping of calls to the
// connection Do and Receive methods, the functions take a second argument of
// type error. If the error is non-nil, then the helper function returns the
// error. If the error is nil, the function converts the reply to the specified
// type:
//
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
// if err != nil {
// // handle error return from c.Do or type conversion error.
// }
//
// The Scan function converts elements of an array reply to Go types:
//
// var value1 int
// var value2 string
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
// if err != nil {
// // handle error
// }
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
// // handle error
// }
package redis

117
client/go/redis/log.go Normal file

@ -0,0 +1,117 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"fmt"
"log"
)
// NewLoggingConn returns a logging wrapper around a connection.
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix}
}
type loggingConn struct {
Conn
logger *log.Logger
prefix string
}
func (c *loggingConn) Close() error {
err := c.Conn.Close()
var buf bytes.Buffer
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
c.logger.Output(2, buf.String())
return err
}
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
const chop = 32
switch v := v.(type) {
case []byte:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case string:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case []interface{}:
if len(v) == 0 {
buf.WriteString("[]")
} else {
sep := "["
fin := "]"
if len(v) > chop {
v = v[:chop]
fin = "...]"
}
for _, vv := range v {
buf.WriteString(sep)
c.printValue(buf, vv)
sep = ", "
}
buf.WriteString(fin)
}
default:
fmt.Fprint(buf, v)
}
}
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
if method != "Receive" {
buf.WriteString(commandName)
for _, arg := range args {
buf.WriteString(", ")
c.printValue(&buf, arg)
}
}
buf.WriteString(") -> (")
if method != "Send" {
c.printValue(&buf, reply)
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%v)", err)
c.logger.Output(3, buf.String())
}
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
reply, err := c.Conn.Do(commandName, args...)
c.print("Do", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err)
return err
}
func (c *loggingConn) Receive() (interface{}, error) {
reply, err := c.Conn.Receive()
c.print("Receive", "", nil, reply, err)
return reply, err
}
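
A short sketch of wrapping an existing Conn c with the logging decorator above (the log destination and prefix are arbitrary choices for illustration):

    logger := log.New(os.Stderr, "", log.LstdFlags)
    c = redis.NewLoggingConn(c, logger, "redis")
    // Each call on c is now echoed, roughly as:
    //   redis.Do(GET, "hello") -> ("world", <nil>)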

358
client/go/redis/pool.go Normal file

@ -0,0 +1,358 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"container/list"
"crypto/rand"
"crypto/sha1"
"errors"
"io"
"strconv"
"sync"
"time"
)
var nowFunc = time.Now // for testing
// ErrPoolExhausted is returned from a pool connection method (Do, Send,
// Receive, Flush, Err) when the maximum number of database connections in the
// pool has been reached.
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
var errPoolClosed = errors.New("redigo: connection pool closed")
// Pool maintains a pool of connections. The application calls the Get method
// to get a connection from the pool and the connection's Close method to
// return the connection's resources to the pool.
//
// The following example shows how to use a pool in a web application. The
// application creates a pool at application startup and makes it available to
// request handlers using a global variable.
//
// func newPool(server, password string) *redis.Pool {
// return &redis.Pool{
// MaxIdle: 3,
// IdleTimeout: 240 * time.Second,
// Dial: func () (redis.Conn, error) {
// c, err := redis.Dial("tcp", server)
// if err != nil {
// return nil, err
// }
// if _, err := c.Do("AUTH", password); err != nil {
// c.Close()
// return nil, err
// }
// return c, err
// },
// TestOnBorrow: func(c redis.Conn, t time.Time) error {
// _, err := c.Do("PING")
// return err
// },
// }
// }
//
// var (
// pool *redis.Pool
// redisServer = flag.String("redisServer", ":6379", "")
// redisPassword = flag.String("redisPassword", "", "")
// )
//
// func main() {
// flag.Parse()
// pool = newPool(*redisServer, *redisPassword)
// ...
// }
//
// A request handler gets a connection from the pool and closes the connection
// when the handler is done:
//
// func serveHome(w http.ResponseWriter, r *http.Request) {
// conn := pool.Get()
// defer conn.Close()
// ....
// }
//
type Pool struct {
// Dial is an application supplied function for creating new connections.
Dial func() (Conn, error)
// TestOnBorrow is an optional application supplied function for checking
// the health of an idle connection before the connection is used again by
// the application. Argument t is the time that the connection was returned
// to the pool. If the function returns an error, then the connection is
// closed.
TestOnBorrow func(c Conn, t time.Time) error
// Maximum number of idle connections in the pool.
MaxIdle int
// Maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
MaxActive int
// Close connections after remaining idle for this duration. If the value
// is zero, then idle connections are not closed. Applications should set
// the timeout to a value less than the server's timeout.
IdleTimeout time.Duration
// mu protects fields defined below.
mu sync.Mutex
closed bool
active int
// Stack of idleConn with most recently used at the front.
idle list.List
}
type idleConn struct {
c Conn
t time.Time
}
// NewPool is a convenience function for initializing a pool.
func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
return &Pool{Dial: newFn, MaxIdle: maxIdle}
}
// Get gets a connection. The application must close the returned connection.
// The connection acquires an underlying connection on the first call to the
// connection Do, Send, Receive, Flush or Err methods. An application can force
// the connection to acquire an underlying connection without executing a Redis
// command by calling the Err method.
func (p *Pool) Get() Conn {
return &pooledConnection{p: p}
}
// ActiveCount returns the number of active connections in the pool.
func (p *Pool) ActiveCount() int {
p.mu.Lock()
active := p.active
p.mu.Unlock()
return active
}
// Close releases the resources used by the pool.
func (p *Pool) Close() error {
p.mu.Lock()
idle := p.idle
p.idle.Init()
p.closed = true
p.active -= idle.Len()
p.mu.Unlock()
for e := idle.Front(); e != nil; e = e.Next() {
e.Value.(idleConn).c.Close()
}
return nil
}
// get prunes stale connections and returns a connection from the idle list or
// creates a new connection.
func (p *Pool) get() (Conn, error) {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return nil, errors.New("redigo: get on closed pool")
}
// Prune stale connections.
if timeout := p.IdleTimeout; timeout > 0 {
for i, n := 0, p.idle.Len(); i < n; i++ {
e := p.idle.Back()
if e == nil {
break
}
ic := e.Value.(idleConn)
if ic.t.Add(timeout).After(nowFunc()) {
break
}
p.idle.Remove(e)
p.active -= 1
p.mu.Unlock()
ic.c.Close()
p.mu.Lock()
}
}
// Get idle connection.
for i, n := 0, p.idle.Len(); i < n; i++ {
e := p.idle.Front()
if e == nil {
break
}
ic := e.Value.(idleConn)
p.idle.Remove(e)
test := p.TestOnBorrow
p.mu.Unlock()
if test == nil || test(ic.c, ic.t) == nil {
return ic.c, nil
}
ic.c.Close()
p.mu.Lock()
p.active -= 1
}
if p.MaxActive > 0 && p.active >= p.MaxActive {
p.mu.Unlock()
return nil, ErrPoolExhausted
}
// No idle connection, create new.
dial := p.Dial
p.active += 1
p.mu.Unlock()
c, err := dial()
if err != nil {
p.mu.Lock()
p.active -= 1
p.mu.Unlock()
c = nil
}
return c, err
}
func (p *Pool) put(c Conn, forceClose bool) error {
if c.Err() == nil && !forceClose {
p.mu.Lock()
if !p.closed {
p.idle.PushFront(idleConn{t: nowFunc(), c: c})
if p.idle.Len() > p.MaxIdle {
c = p.idle.Remove(p.idle.Back()).(idleConn).c
} else {
c = nil
}
}
p.mu.Unlock()
}
if c != nil {
p.mu.Lock()
p.active -= 1
p.mu.Unlock()
return c.Close()
}
return nil
}
type pooledConnection struct {
c Conn
err error
p *Pool
state int
}
func (c *pooledConnection) get() error {
if c.err == nil && c.c == nil {
c.c, c.err = c.p.get()
}
return c.err
}
var (
sentinel []byte
sentinelOnce sync.Once
)
func initSentinel() {
p := make([]byte, 64)
if _, err := rand.Read(p); err == nil {
sentinel = p
} else {
h := sha1.New()
io.WriteString(h, "Oops, rand failed. Use time instead.")
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
sentinel = h.Sum(nil)
}
}
func (c *pooledConnection) Close() (err error) {
if c.c != nil {
if c.state&multiState != 0 {
c.c.Send("DISCARD")
c.state &^= (multiState | watchState)
} else if c.state&watchState != 0 {
c.c.Send("UNWATCH")
c.state &^= watchState
}
if c.state&subscribeState != 0 {
c.c.Send("UNSUBSCRIBE")
c.c.Send("PUNSUBSCRIBE")
// To detect the end of the message stream, ask the server to echo
// a sentinel value and read until we see that value.
sentinelOnce.Do(initSentinel)
c.c.Send("ECHO", sentinel)
c.c.Flush()
for {
p, err := c.c.Receive()
if err != nil {
break
}
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
c.state &^= subscribeState
break
}
}
}
c.c.Do("")
c.p.put(c.c, c.state != 0)
c.c = nil
c.err = errPoolClosed
}
return err
}
func (c *pooledConnection) Err() error {
if err := c.get(); err != nil {
return err
}
return c.c.Err()
}
func (c *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
if err := c.get(); err != nil {
return nil, err
}
ci := lookupCommandInfo(commandName)
c.state = (c.state | ci.set) &^ ci.clear
return c.c.Do(commandName, args...)
}
func (c *pooledConnection) Send(commandName string, args ...interface{}) error {
if err := c.get(); err != nil {
return err
}
ci := lookupCommandInfo(commandName)
c.state = (c.state | ci.set) &^ ci.clear
return c.c.Send(commandName, args...)
}
func (c *pooledConnection) Flush() error {
if err := c.get(); err != nil {
return err
}
return c.c.Flush()
}
func (c *pooledConnection) Receive() (reply interface{}, err error) {
if err := c.get(); err != nil {
return nil, err
}
return c.c.Receive()
}
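
A brief sketch of checking out a connection from a pool built as in the example above (per the Get documentation, the underlying connection is acquired lazily; calling Err forces acquisition without sending a command):

    conn := pool.Get()
    defer conn.Close() // returns the connection to the pool

    if err := conn.Err(); err != nil {
        // dial failed, or the pool is closed or exhausted
    }
    reply, err := conn.Do("GET", "counter")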

129
client/go/redis/pubsub.go Normal file

@ -0,0 +1,129 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
)
// Subscription represents a subscribe or unsubscribe notification.
type Subscription struct {
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
Kind string
// The channel that was changed.
Channel string
// The current number of subscriptions for the connection.
Count int
}
// Message represents a message notification.
type Message struct {
// The originating channel.
Channel string
// The message data.
Data []byte
}
// PMessage represents a pmessage notification.
type PMessage struct {
// The matched pattern.
Pattern string
// The originating channel.
Channel string
// The message data.
Data []byte
}
// PubSubConn wraps a Conn with convenience methods for subscribers.
type PubSubConn struct {
Conn Conn
}
// Close closes the connection.
func (c PubSubConn) Close() error {
return c.Conn.Close()
}
// Subscribe subscribes the connection to the specified channels.
func (c PubSubConn) Subscribe(channel ...interface{}) error {
c.Conn.Send("SUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PSubscribe subscribes the connection to the given patterns.
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
c.Conn.Send("PSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Unsubscribe unsubscribes the connection from the given channels, or from all
// of them if none is given.
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
c.Conn.Send("UNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
// of them if none is given.
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
c.Conn.Send("PUNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Receive returns a pushed message as a Subscription, Message, PMessage or
// error. The return value is intended to be used directly in a type switch as
// illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive())
if err != nil {
return err
}
var kind string
reply, err = Scan(reply, &kind)
if err != nil {
return err
}
switch kind {
case "message":
var m Message
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
return err
}
return m
case "pmessage":
var pm PMessage
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
return err
}
return pm
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
s := Subscription{Kind: kind}
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
return err
}
return s
}
return errors.New("redigo: unknown pubsub notification")
}

44
client/go/redis/redis.go Normal file

@ -0,0 +1,44 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
}

271
client/go/redis/reply.go Normal file

@ -0,0 +1,271 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"strconv"
)
// ErrNil indicates that a reply value is nil.
var ErrNil = errors.New("redigo: nil returned")
// Int is a helper that converts a command reply to an integer. If err is not
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
// reply to an int as follows:
//
// Reply type Result
// integer int(reply), nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int(reply interface{}, err error) (int, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
x := int(reply)
if int64(x) != reply {
return 0, strconv.ErrRange
}
return x, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 0)
return int(n), err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Int, got type %T", reply)
}
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int64 returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int64(reply interface{}, err error) (int64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
return reply, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply)
}
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
// Uint64 is a helper that converts a command reply to a 64 bit unsigned
// integer. If err is not equal to nil, then Uint64 returns 0, err. Otherwise,
// Uint64 converts the reply to a uint64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Uint64(reply interface{}, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
if reply < 0 {
return 0, errNegativeInt
}
return uint64(reply), nil
case []byte:
n, err := strconv.ParseUint(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
}
// Float64 is a helper that converts a command reply to 64 bit float. If err is
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
// the reply to a float64 as follows:
//
// Reply type Result
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Float64(reply interface{}, err error) (float64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case []byte:
n, err := strconv.ParseFloat(string(reply), 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Float64, got type %T", reply)
}
// String is a helper that converts a command reply to a string. If err is not
// equal to nil, then String returns "", err. Otherwise String converts the
// reply to a string as follows:
//
// Reply type Result
// bulk string string(reply), nil
// simple string reply, nil
// nil "", ErrNil
// other "", error
func String(reply interface{}, err error) (string, error) {
if err != nil {
return "", err
}
switch reply := reply.(type) {
case []byte:
return string(reply), nil
case string:
return reply, nil
case nil:
return "", ErrNil
case Error:
return "", reply
}
return "", fmt.Errorf("redigo: unexpected type for String, got type %T", reply)
}
// Bytes is a helper that converts a command reply to a slice of bytes. If err
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
// the reply to a slice of bytes as follows:
//
// Reply type Result
// bulk string reply, nil
// simple string []byte(reply), nil
// nil nil, ErrNil
// other nil, error
func Bytes(reply interface{}, err error) ([]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []byte:
return reply, nil
case string:
return []byte(reply), nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
}
// Bool is a helper that converts a command reply to a boolean. If err is not
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
// reply to boolean as follows:
//
// Reply type Result
// integer value != 0, nil
// bulk string strconv.ParseBool(reply)
// nil false, ErrNil
// other false, error
func Bool(reply interface{}, err error) (bool, error) {
if err != nil {
return false, err
}
switch reply := reply.(type) {
case int64:
return reply != 0, nil
case []byte:
return strconv.ParseBool(string(reply))
case nil:
return false, ErrNil
case Error:
return false, reply
}
return false, fmt.Errorf("redigo: unexpected type for Bool, got type %T", reply)
}
// MultiBulk is deprecated. Use Values.
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
// Values is a helper that converts an array command reply to a []interface{}.
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
// converts the reply as follows:
//
// Reply type Result
// array reply, nil
// nil nil, ErrNil
// other nil, error
func Values(reply interface{}, err error) ([]interface{}, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
return reply, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply)
}
// Strings is a helper that converts an array command reply to a []string. If
// err is not equal to nil, then Strings returns nil, err. If one of the array
// items is not a bulk string or nil, then Strings returns an error.
func Strings(reply interface{}, err error) ([]string, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([]string, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
}
result[i] = string(p)
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply)
}
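
A small sketch of the helpers above wrapping Do calls directly (commands and keys are illustrative; the package is assumed to be imported as redis):

    n, err := redis.Int(c.Do("INCR", "counter"))        // integer reply -> int
    s, err := redis.String(c.Do("GET", "greeting"))     // bulk string -> string; "", ErrNil if the key is missing
    ids, err := redis.Strings(c.Do("SMEMBERS", "ids"))  // array of bulk strings -> []string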

513
client/go/redis/scan.go Normal file

@ -0,0 +1,513 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
)
func ensureLen(d reflect.Value, n int) {
if n > d.Cap() {
d.Set(reflect.MakeSlice(d.Type(), n, n))
} else {
d.SetLen(n)
}
}
func cannotConvert(d reflect.Value, s interface{}) error {
return fmt.Errorf("redigo: Scan cannot convert from %s to %s",
reflect.TypeOf(s), d.Type())
}
func convertAssignBytes(d reflect.Value, s []byte) (err error) {
switch d.Type().Kind() {
case reflect.Float32, reflect.Float64:
var x float64
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
d.SetFloat(x)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var x int64
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
d.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var x uint64
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
d.SetUint(x)
case reflect.Bool:
var x bool
x, err = strconv.ParseBool(string(s))
d.SetBool(x)
case reflect.String:
d.SetString(string(s))
case reflect.Slice:
if d.Type().Elem().Kind() != reflect.Uint8 {
err = cannotConvert(d, s)
} else {
d.SetBytes(s)
}
default:
err = cannotConvert(d, s)
}
return
}
func convertAssignInt(d reflect.Value, s int64) (err error) {
switch d.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
d.SetInt(s)
if d.Int() != s {
err = strconv.ErrRange
d.SetInt(0)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if s < 0 {
err = strconv.ErrRange
} else {
x := uint64(s)
d.SetUint(x)
if d.Uint() != x {
err = strconv.ErrRange
d.SetUint(0)
}
}
case reflect.Bool:
d.SetBool(s != 0)
default:
err = cannotConvert(d, s)
}
return
}
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
switch s := s.(type) {
case []byte:
err = convertAssignBytes(d, s)
case int64:
err = convertAssignInt(d, s)
default:
err = cannotConvert(d, s)
}
return err
}
func convertAssignValues(d reflect.Value, s []interface{}) error {
if d.Type().Kind() != reflect.Slice {
return cannotConvert(d, s)
}
ensureLen(d, len(s))
for i := 0; i < len(s); i++ {
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
return err
}
}
return nil
}
func convertAssign(d interface{}, s interface{}) (err error) {
// Handle the most common destination types using type switches and
// fall back to reflection for all other types.
switch s := s.(type) {
case nil:
// ignore
case []byte:
switch d := d.(type) {
case *string:
*d = string(s)
case *int:
*d, err = strconv.Atoi(string(s))
case *bool:
*d, err = strconv.ParseBool(string(s))
case *[]byte:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignBytes(d.Elem(), s)
}
}
case int64:
switch d := d.(type) {
case *int:
x := int(s)
if int64(x) != s {
err = strconv.ErrRange
x = 0
}
*d = x
case *bool:
*d = s != 0
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignInt(d.Elem(), s)
}
}
case []interface{}:
switch d := d.(type) {
case *[]interface{}:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignValues(d.Elem(), s)
}
}
case Error:
err = s
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
return
}
// Scan copies from src to the values pointed at by dest.
//
// The values pointed at by dest must be an integer, float, boolean, string,
// []byte, interface{} or slices of these types. Scan uses the standard strconv
// package to convert bulk strings to numeric and boolean types.
//
// If a dest value is nil, then the corresponding src value is skipped.
//
// If a src element is nil, then the corresponding dest value is not modified.
//
// To enable easy use of Scan in a loop, Scan returns the slice of src
// following the copied values.
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
if len(src) < len(dest) {
return nil, errors.New("redigo: Scan array short")
}
var err error
for i, d := range dest {
err = convertAssign(d, src[i])
if err != nil {
break
}
}
return src[len(dest):], err
}
type fieldSpec struct {
name string
index []int
//omitEmpty bool
}
type structSpec struct {
m map[string]*fieldSpec
l []*fieldSpec
}
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
switch {
case f.PkgPath != "":
// Ignore unexported fields.
case f.Anonymous:
// TODO: Handle pointers. Requires change to decoder and
// protection against infinite recursion.
if f.Type.Kind() == reflect.Struct {
compileStructSpec(f.Type, depth, append(index, i), ss)
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")
p := strings.Split(tag, ",")
if len(p) > 0 {
if p[0] == "-" {
continue
}
if len(p[0]) > 0 {
fs.name = p[0]
}
for _, s := range p[1:] {
switch s {
//case "omitempty":
// fs.omitempty = true
default:
panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name()))
}
}
}
d, found := depth[fs.name]
if !found {
d = 1 << 30
}
switch {
case len(index) == d:
// At same depth, remove from result.
delete(ss.m, fs.name)
j := 0
for i := 0; i < len(ss.l); i++ {
if fs.name != ss.l[i].name {
ss.l[j] = ss.l[i]
j += 1
}
}
ss.l = ss.l[:j]
case len(index) < d:
fs.index = make([]int, len(index)+1)
copy(fs.index, index)
fs.index[len(index)] = i
depth[fs.name] = len(index)
ss.m[fs.name] = fs
ss.l = append(ss.l, fs)
}
}
}
}
var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
defaultFieldSpec = &fieldSpec{}
)
func structSpecForType(t reflect.Type) *structSpec {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
}
structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
}
ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
structSpecCache[t] = ss
return ss
}
var errScanStructValue = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct")
// ScanStruct scans alternating names and values from src to a struct. The
// HGETALL and CONFIG GET commands return replies in this format.
//
// ScanStruct uses exported field names to match values in the response. Use
// 'redis' field tag to override the name:
//
// Field int `redis:"myName"`
//
// Fields with the tag redis:"-" are ignored.
//
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
// standard strconv package to convert bulk string values to numeric and
// boolean types.
//
// If a src element is nil, then the corresponding field is not modified.
func ScanStruct(src []interface{}, dest interface{}) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanStructValue
}
d = d.Elem()
if d.Kind() != reflect.Struct {
return errScanStructValue
}
ss := structSpecForType(d.Type())
if len(src)%2 != 0 {
return errors.New("redigo: ScanStruct expects even number of values in values")
}
for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}
name, ok := src[i].([]byte)
if !ok {
return errors.New("redigo: ScanStruct key not a bulk string value")
}
fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return err
}
}
return nil
}
var (
errScanSliceValue = errors.New("redigo: ScanSlice dest must be non-nil pointer to a slice")
)
// ScanSlice scans src to the slice pointed to by dest. The elements of the dest
// slice must be integer, float, boolean, string, struct or pointer to struct
// values.
//
// Struct fields must be integer, float, boolean or string values. All struct
// fields are used unless a subset is specified using fieldNames.
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanSliceValue
}
d = d.Elem()
if d.Kind() != reflect.Slice {
return errScanSliceValue
}
isPtr := false
t := d.Type().Elem()
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
ensureLen(d, len(src))
for i, s := range src {
if s == nil {
continue
}
if err := convertAssignValue(d.Index(i), s); err != nil {
return err
}
}
return nil
}
ss := structSpecForType(t)
fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
for i, name := range fieldNames {
fss[i] = ss.m[name]
if fss[i] == nil {
return errors.New("redigo: ScanSlice bad field name " + name)
}
}
}
if len(fss) == 0 {
return errors.New("redigo: ScanSlice no struct fields")
}
n := len(src) / len(fss)
if n*len(fss) != len(src) {
return errors.New("redigo: ScanSlice length not a multiple of struct field count")
}
ensureLen(d, n)
for i := 0; i < n; i++ {
d := d.Index(i)
if isPtr {
if d.IsNil() {
d.Set(reflect.New(t))
}
d = d.Elem()
}
for j, fs := range fss {
s := src[i*len(fss)+j]
if s == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return err
}
}
}
return nil
}
// Args is a helper for constructing command arguments from structured values.
type Args []interface{}
// Add returns the result of appending value to args.
func (args Args) Add(value ...interface{}) Args {
return append(args, value...)
}
// AddFlat returns the result of appending the flattened value of v to args.
//
// Maps are flattened by appending the alternating keys and map values to args.
//
// Slices are flattened by appending the slice elements to args.
//
// Structs are flattened by appending the alternating names and values of
// exported fields to args. If v is a nil struct pointer, then nothing is
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Struct:
args = flattenStruct(args, rv)
case reflect.Slice:
for i := 0; i < rv.Len(); i++ {
args = append(args, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
}
case reflect.Ptr:
if rv.Type().Elem().Kind() == reflect.Struct {
if !rv.IsNil() {
args = flattenStruct(args, rv.Elem())
}
} else {
args = append(args, v)
}
default:
args = append(args, v)
}
return args
}
func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
args = append(args, fs.name, fv.Interface())
}
return args
}
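
A sketch tying Args.AddFlat and ScanStruct together for a hash round trip (the struct, key and field names are made up; HMSET/HGETALL is the usual command pair for this pattern, and the package is assumed to be imported as redis):

    type Entry struct {
        Title string `redis:"title"`
        Views int    `redis:"views"`
        Draft bool   `redis:"-"` // skipped because of the "-" tag
    }

    e := Entry{Title: "hello", Views: 3}
    // Flattens to: "entry:1", "title", "hello", "views", 3
    if _, err := c.Do("HMSET", redis.Args{}.Add("entry:1").AddFlat(&e)...); err != nil {
        // handle error
    }

    var got Entry
    v, err := redis.Values(c.Do("HGETALL", "entry:1"))
    if err == nil {
        err = redis.ScanStruct(v, &got)
    }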

86
client/go/redis/script.go Normal file

@ -0,0 +1,86 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"crypto/sha1"
"encoding/hex"
"io"
"strings"
)
// Script encapsulates the source, hash and key count for a Lua script. See
// http://redis.io/commands/eval for information on scripts in Redis.
type Script struct {
keyCount int
src string
hash string
}
// NewScript returns a new script object. If keyCount is greater than or equal
// to zero, then the count is automatically inserted in the EVAL command
// argument list. If keyCount is less than zero, then the application supplies
// the count as the first value in the keysAndArgs argument to the Do, Send and
// SendHash methods.
func NewScript(keyCount int, src string) *Script {
h := sha1.New()
io.WriteString(h, src)
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
}
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
var args []interface{}
if s.keyCount < 0 {
args = make([]interface{}, 1+len(keysAndArgs))
args[0] = spec
copy(args[1:], keysAndArgs)
} else {
args = make([]interface{}, 2+len(keysAndArgs))
args[0] = spec
args[1] = s.keyCount
copy(args[2:], keysAndArgs)
}
return args
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
// causing the script to load).
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}
// SendHash evaluates the script without waiting for the reply. The script is
// evaluated with the EVALSHA command. The application must ensure that the
// script is loaded by a previous call to Send, Do or Load methods.
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
}
// Send evaluates the script without waiting for the reply.
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
}
// Load loads the script without evaluating it.
func (s *Script) Load(c Conn) error {
_, err := c.Do("SCRIPT", "LOAD", s.src)
return err
}
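
A usage sketch for Script (the Lua body and key are illustrative; with keyCount set to 1, the key count is inserted into the EVAL/EVALSHA argument list automatically):

    var incrByScript = redis.NewScript(1, `
        return redis.call("INCRBY", KEYS[1], ARGV[1])
    `)

    // Do tries EVALSHA first and falls back to EVAL if the script is not yet cached.
    n, err := redis.Int64(incrByScript.Do(c, "counter", 5))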


@ -0,0 +1,31 @@
from redis.client import Redis, StrictRedis
from redis.connection import (
BlockingConnectionPool,
ConnectionPool,
Connection,
UnixDomainSocketConnection
)
from redis.utils import from_url
from redis.exceptions import (
AuthenticationError,
ConnectionError,
BusyLoadingError,
DataError,
InvalidResponse,
PubSubError,
RedisError,
ResponseError,
WatchError,
)
__version__ = '2.7.6'
VERSION = tuple(map(int, __version__.split('.')))
__all__ = [
'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool',
'Connection', 'UnixDomainSocketConnection',
'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError',
'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url',
'BusyLoadingError'
]


@ -0,0 +1,79 @@
"""Internal module for Python 2 backwards compatibility."""
import sys
if sys.version_info[0] < 3:
from urlparse import urlparse
from itertools import imap, izip
from string import letters as ascii_letters
from Queue import Queue
try:
from cStringIO import StringIO as BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
iteritems = lambda x: x.iteritems()
iterkeys = lambda x: x.iterkeys()
itervalues = lambda x: x.itervalues()
nativestr = lambda x: \
x if isinstance(x, str) else x.encode('utf-8', 'replace')
u = lambda x: x.decode()
b = lambda x: x
next = lambda x: x.next()
byte_to_chr = lambda x: x
unichr = unichr
xrange = xrange
basestring = basestring
unicode = unicode
bytes = str
long = long
else:
from urllib.parse import urlparse
from io import BytesIO
from string import ascii_letters
from queue import Queue
iteritems = lambda x: iter(x.items())
iterkeys = lambda x: iter(x.keys())
itervalues = lambda x: iter(x.values())
byte_to_chr = lambda x: chr(x)
nativestr = lambda x: \
x if isinstance(x, str) else x.decode('utf-8', 'replace')
u = lambda x: x
b = lambda x: x.encode('iso-8859-1') if not isinstance(x, bytes) else x
next = next
unichr = chr
imap = map
izip = zip
xrange = range
basestring = str
unicode = str
bytes = bytes
long = int
try: # Python 3
from queue import LifoQueue, Empty, Full
except ImportError:
from Queue import Empty, Full
try: # Python 2.6 - 2.7
from Queue import LifoQueue
except ImportError: # Python 2.5
from Queue import Queue
# From the Python 2.7 lib. Python 2.5 already extracted the core
# methods to aid implementing different queue organisations.
class LifoQueue(Queue):
"Override queue methods to implement a last-in first-out queue."
def _init(self, maxsize):
self.maxsize = maxsize
self.queue = []
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item):
self.queue.append(item)
def _get(self):
return self.queue.pop()

File diff suppressed because it is too large


@ -0,0 +1,580 @@
from itertools import chain
import os
import socket
import sys
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
BytesIO, nativestr, basestring,
LifoQueue, Empty, Full)
from redis.exceptions import (
RedisError,
ConnectionError,
BusyLoadingError,
ResponseError,
InvalidResponse,
AuthenticationError,
NoScriptError,
ExecAbortError,
)
from redis.utils import HIREDIS_AVAILABLE
if HIREDIS_AVAILABLE:
import hiredis
SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
SYM_LF = b('\n')
class PythonParser(object):
"Plain Python parsing class"
MAX_READ_LENGTH = 1000000
encoding = None
EXCEPTION_CLASSES = {
'ERR': ResponseError,
'EXECABORT': ExecAbortError,
'LOADING': BusyLoadingError,
'NOSCRIPT': NoScriptError,
}
def __init__(self):
self._fp = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
"Called when the socket connects"
self._fp = connection._sock.makefile('rb')
if connection.decode_responses:
self.encoding = connection.encoding
def on_disconnect(self):
"Called when the socket disconnects"
if self._fp is not None:
self._fp.close()
self._fp = None
def read(self, length=None):
"""
Read a line from the socket if no length is specified,
otherwise read ``length`` bytes. Always strip away the newlines.
"""
try:
if length is not None:
bytes_left = length + 2 # read the line ending
if length > self.MAX_READ_LENGTH:
# apparently reading more than 1MB or so from a Windows
# socket can cause MemoryErrors. See:
# https://github.com/andymccurdy/redis-py/issues/205
# read smaller chunks at a time to work around this
try:
buf = BytesIO()
while bytes_left > 0:
read_len = min(bytes_left, self.MAX_READ_LENGTH)
buf.write(self._fp.read(read_len))
bytes_left -= read_len
buf.seek(0)
return buf.read(length)
finally:
buf.close()
return self._fp.read(bytes_left)[:-2]
# no length, read a full line
return self._fp.readline()[:-2]
except (socket.error, socket.timeout):
e = sys.exc_info()[1]
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
def parse_error(self, response):
"Parse an error response"
error_code = response.split(' ')[0]
if error_code in self.EXCEPTION_CLASSES:
response = response[len(error_code) + 1:]
return self.EXCEPTION_CLASSES[error_code](response)
return ResponseError(response)
def read_response(self):
response = self.read()
if not response:
raise ConnectionError("Socket closed on remote end")
byte, response = byte_to_chr(response[0]), response[1:]
if byte not in ('-', '+', ':', '$', '*'):
raise InvalidResponse("Protocol Error")
# server returned an error
if byte == '-':
response = nativestr(response)
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == '+':
pass
# int value
elif byte == ':':
response = long(response)
# bulk response
elif byte == '$':
length = int(response)
if length == -1:
return None
response = self.read(length)
# multi-bulk response
elif byte == '*':
length = int(response)
if length == -1:
return None
response = [self.read_response() for i in xrange(length)]
if isinstance(response, bytes) and self.encoding:
response = response.decode(self.encoding)
return response
class HiredisParser(object):
"Parser class for connections using Hiredis"
def __init__(self):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
self._sock = connection._sock
kwargs = {
'protocolError': InvalidResponse,
'replyError': ResponseError,
}
if connection.decode_responses:
kwargs['encoding'] = connection.encoding
self._reader = hiredis.Reader(**kwargs)
def on_disconnect(self):
self._sock = None
self._reader = None
def read_response(self):
if not self._reader:
raise ConnectionError("Socket closed on remote end")
response = self._reader.gets()
while response is False:
try:
buffer = self._sock.recv(4096)
except (socket.error, socket.timeout):
e = sys.exc_info()[1]
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
if not buffer:
raise ConnectionError("Socket closed on remote end")
self._reader.feed(buffer)
# proactively, but not conclusively, check if more data is in the
# buffer. if the data received doesn't end with \n, there's more.
if not buffer.endswith(SYM_LF):
continue
response = self._reader.gets()
return response
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
DefaultParser = PythonParser
class Connection(object):
"Manages TCP communication to and from a Redis server"
def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser):
self.pid = os.getpid()
self.host = host
self.port = port
self.db = db
self.password = password
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
def __del__(self):
try:
self.disconnect()
except Exception:
pass
def connect(self):
"Connects to the Redis server if not already connected"
if self._sock:
return
try:
sock = self._connect()
except socket.error:
e = sys.exc_info()[1]
raise ConnectionError(self._error_message(e))
self._sock = sock
self.on_connect()
def _connect(self):
"Create a TCP socket connection"
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.socket_timeout)
sock.connect((self.host, self.port))
return sock
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
if len(exception.args) == 1:
return "Error connecting to %s:%s. %s." % \
(self.host, self.port, exception.args[0])
else:
return "Error %s connecting %s:%s. %s." % \
(exception.args[0], self.host, self.port, exception.args[1])
def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
# if a password is specified, authenticate
if self.password:
self.send_command('AUTH', self.password)
if nativestr(self.read_response()) != 'OK':
raise AuthenticationError('Invalid Password')
# if a database is specified, switch to it
if self.db:
self.send_command('SELECT', self.db)
if nativestr(self.read_response()) != 'OK':
raise ConnectionError('Invalid Database')
def disconnect(self):
"Disconnects from the Redis server"
self._parser.on_disconnect()
if self._sock is None:
return
try:
self._sock.close()
except socket.error:
pass
self._sock = None
def send_packed_command(self, command):
"Send an already packed command to the Redis server"
if not self._sock:
self.connect()
try:
self._sock.sendall(command)
except socket.error:
e = sys.exc_info()[1]
self.disconnect()
if len(e.args) == 1:
_errno, errmsg = 'UNKNOWN', e.args[0]
else:
_errno, errmsg = e.args
raise ConnectionError("Error %s while writing to socket. %s." %
(_errno, errmsg))
except Exception:
self.disconnect()
raise
def send_command(self, *args):
"Pack and send a command to the Redis server"
self.send_packed_command(self.pack_command(*args))
def read_response(self):
"Read the response from a previously sent command"
try:
response = self._parser.read_response()
except Exception:
self.disconnect()
raise
if isinstance(response, ResponseError):
raise response
return response
def encode(self, value):
"Return a bytestring representation of the value"
if isinstance(value, bytes):
return value
if isinstance(value, float):
value = repr(value)
if not isinstance(value, basestring):
value = str(value)
if isinstance(value, unicode):
value = value.encode(self.encoding, self.encoding_errors)
return value
def pack_command(self, *args):
"Pack a series of arguments into a value Redis command"
output = SYM_STAR + b(str(len(args))) + SYM_CRLF
for enc_value in imap(self.encode, args):
output += SYM_DOLLAR
output += b(str(len(enc_value)))
output += SYM_CRLF
output += enc_value
output += SYM_CRLF
return output
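# Illustrative note (not part of the original client): the framing above is
# the Redis unified request protocol (RESP), e.g.
#   pack_command('SET', 'foo', 'bar')
#   == b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'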
class UnixDomainSocketConnection(Connection):
def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser):
self.pid = os.getpid()
self.path = path
self.db = db
self.password = password
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.socket_timeout)
sock.connect(self.path)
return sock
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
if len(exception.args) == 1:
return "Error connecting to unix socket: %s. %s." % \
(self.path, exception.args[0])
else:
return "Error %s connecting to unix socket: %s. %s." % \
(exception.args[0], self.path, exception.args[1])
# TODO: add ability to block waiting on a connection to be released
class ConnectionPool(object):
"Generic connection pool"
def __init__(self, connection_class=Connection, max_connections=None,
**connection_kwargs):
self.pid = os.getpid()
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections or 2 ** 31
self._created_connections = 0
self._available_connections = []
self._in_use_connections = set()
def _checkpid(self):
if self.pid != os.getpid():
self.disconnect()
self.__init__(self.connection_class, self.max_connections,
**self.connection_kwargs)
def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
return connection
def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid == self.pid:
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
def disconnect(self):
"Disconnects all connections in the pool"
all_conns = chain(self._available_connections,
self._in_use_connections)
for connection in all_conns:
connection.disconnect()
class BlockingConnectionPool(object):
"""
Thread-safe blocking connection pool::
>>> from redis.client import Redis
>>> client = Redis(connection_pool=BlockingConnectionPool())
It performs the same function as the default
``:py:class: ~redis.connection.ConnectionPool`` implementation, in that
it maintains a pool of reusable connections that can be shared by
multiple redis clients (safely across threads if required).
The difference is that, in the event that a client tries to get a
connection from the pool when all of the connections are in use, rather than
raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
makes the client wait ("blocks") for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
>>> pool = BlockingConnectionPool(max_connections=10)
Use ``timeout`` to tell it either how many seconds to wait for a connection
to become available, or to block forever:
# Block forever.
>>> pool = BlockingConnectionPool(timeout=None)
# Raise a ``ConnectionError`` after five seconds if a connection is
# not available.
>>> pool = BlockingConnectionPool(timeout=5)
"""
def __init__(self, max_connections=50, timeout=20, connection_class=None,
queue_class=None, **connection_kwargs):
"Compose and assign values."
# Compose.
if connection_class is None:
connection_class = Connection
if queue_class is None:
queue_class = LifoQueue
# Assign.
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.queue_class = queue_class
self.max_connections = max_connections
self.timeout = timeout
# Validate the ``max_connections``. With the "fill up the queue"
# algorithm we use, it must be a positive integer.
is_valid = isinstance(max_connections, int) and max_connections > 0
if not is_valid:
raise ValueError('``max_connections`` must be a positive integer')
# Get the current process id, so we can disconnect and reinstantiate if
# it changes.
self.pid = os.getpid()
# Create and fill up a thread safe queue with ``None`` values.
self.pool = self.queue_class(max_connections)
while True:
try:
self.pool.put_nowait(None)
except Full:
break
# Keep a list of actual connection instances so that we can
# disconnect them later.
self._connections = []
def _checkpid(self):
"""
Check the current process id. If it has changed, disconnect and
re-instantiate this connection pool instance.
"""
# Get the current process id.
pid = os.getpid()
# If it hasn't changed since we were instantiated, then we're fine, so
# just exit, remaining connected.
if self.pid == pid:
return
# If it has changed, then disconnect and re-instantiate.
self.disconnect()
self.reinstantiate()
def make_connection(self):
"Make a fresh connection."
connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
return connection
def get_connection(self, command_name, *keys, **options):
"""
Get a connection, blocking for ``self.timeout`` until a connection
is available from the pool.
If the connection returned is ``None``, a new connection is created.
Because we use a last-in first-out queue, the existing connections
(having been returned to the pool after the initial ``None`` values
were added) will be returned before ``None`` values. This means we only
create new connections when we need to, i.e.: the actual number of
connections will only increase in response to demand.
"""
# Make sure we haven't changed process.
self._checkpid()
# Try and get a connection from the pool. If one isn't available within
# self.timeout then raise a ``ConnectionError``.
connection = None
try:
connection = self.pool.get(block=True, timeout=self.timeout)
except Empty:
# Note that this is not caught by the redis client and will be
# raised unless handled by application code.
raise ConnectionError("No connection available.")
# If the ``connection`` is actually ``None`` then that's a cue to make
# a new connection to add to the pool.
if connection is None:
connection = self.make_connection()
return connection
def release(self, connection):
"Releases the connection back to the pool."
# Make sure we haven't changed process.
self._checkpid()
# Put the connection back into the pool.
try:
self.pool.put_nowait(connection)
except Full:
# This shouldn't normally happen, but can occur after a
# reinstantiation. In that case we handle the exception by not
# putting the connection back on the pool, because we definitely
# do not want to reuse it.
pass
def disconnect(self):
"Disconnects all connections in the pool."
for connection in self._connections:
connection.disconnect()
def reinstantiate(self):
"""
Reinstantiate this instance within a new process with a new connection
pool set.
"""
self.__init__(max_connections=self.max_connections,
timeout=self.timeout,
connection_class=self.connection_class,
queue_class=self.queue_class, **self.connection_kwargs)
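
The BlockingConnectionPool above pre-fills its LIFO queue with ``None`` sentinels: popping ``None`` means "no connection exists yet, create one", while releasing a real connection pushes it on top so it is reused before any remaining sentinels. A stripped-down sketch of that idea (the class and names are illustrative, not the library API):

from queue import LifoQueue, Empty

class TinySentinelPool(object):
    "Toy illustration of the fill-with-None queue used above."
    def __init__(self, make_conn, max_connections=3, timeout=5):
        self.make_conn = make_conn
        self.timeout = timeout
        self.pool = LifoQueue(max_connections)
        for _ in range(max_connections):
            self.pool.put_nowait(None)          # sentinel: create on demand

    def get_connection(self):
        try:
            conn = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            raise RuntimeError("No connection available.")
        return conn if conn is not None else self.make_conn()

    def release(self, conn):
        self.pool.put_nowait(conn)              # real connections now sit above the sentinels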

View File

@ -0,0 +1,49 @@
"Core exceptions raised by the Redis client"
class RedisError(Exception):
pass
class AuthenticationError(RedisError):
pass
class ServerError(RedisError):
pass
class ConnectionError(ServerError):
pass
class BusyLoadingError(ConnectionError):
pass
class InvalidResponse(ServerError):
pass
class ResponseError(RedisError):
pass
class DataError(RedisError):
pass
class PubSubError(RedisError):
pass
class WatchError(RedisError):
pass
class NoScriptError(ResponseError):
pass
class ExecAbortError(ResponseError):
pass
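
One practical consequence of this hierarchy is that a single ``except ConnectionError`` clause also catches ``BusyLoadingError``, while script-related failures (``NoScriptError``, ``ExecAbortError``) surface as ``ResponseError``. A quick illustrative check, assuming the package imports as ``redis``:

import redis

# BusyLoadingError subclasses ConnectionError, so one handler covers both.
assert issubclass(redis.BusyLoadingError, redis.ConnectionError)

try:
    raise redis.BusyLoadingError("Redis is loading the dataset in memory")
except redis.ConnectionError as exc:
    print("connection-level failure:", exc)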

View File

@ -0,0 +1,16 @@
try:
import hiredis
HIREDIS_AVAILABLE = True
except ImportError:
HIREDIS_AVAILABLE = False
def from_url(url, db=None, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
Will attempt to extract the database id from the URL path, if
none is provided.
"""
from redis.client import Redis
return Redis.from_url(url, db, **kwargs)
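
Since ``from_url`` simply delegates to ``Redis.from_url``, typical usage looks like the following (host, port and db are hypothetical):

import redis

# Connect to a hypothetical local server; db 0 comes from the URL path.
client = redis.from_url('redis://localhost:6379/0')
client.set('greeting', 'hello')
print(client.get('greeting'))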

61
client/ledis-py/setup.py Normal file
View File

@ -0,0 +1,61 @@
#!/usr/bin/env python
import os
import sys
from redis import __version__
try:
from setuptools import setup
from setuptools.command.test import test as TestCommand
class PyTest(TestCommand):
def finalize_options(self):
TestCommand.finalize_options(self)
self.test_args = []
self.test_suite = True
def run_tests(self):
# import here, because outside the eggs aren't loaded
import pytest
errno = pytest.main(self.test_args)
sys.exit(errno)
except ImportError:
from distutils.core import setup
PyTest = lambda x: x
f = open(os.path.join(os.path.dirname(__file__), 'README.rst'))
long_description = f.read()
f.close()
setup(
name='redis',
version=__version__,
description='Python client for Redis key-value store',
long_description=long_description,
url='http://github.com/andymccurdy/redis-py',
author='Andy McCurdy',
author_email='sedrik@gmail.com',
maintainer='Andy McCurdy',
maintainer_email='sedrik@gmail.com',
keywords=['Redis', 'key-value store'],
license='MIT',
packages=['redis'],
tests_require=['pytest>=2.5.0'],
cmdclass={'test': PyTest},
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
]
)

View File

View File

@ -0,0 +1,46 @@
import pytest
import redis
from distutils.version import StrictVersion
_REDIS_VERSIONS = {}
def get_version(**kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
key = '%s:%s' % (params['host'], params['port'])
if key not in _REDIS_VERSIONS:
client = redis.Redis(**params)
_REDIS_VERSIONS[key] = client.info()['redis_version']
client.connection_pool.disconnect()
return _REDIS_VERSIONS[key]
def _get_client(cls, request=None, **kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
client = cls(**params)
client.flushdb()
if request:
def teardown():
client.flushdb()
client.connection_pool.disconnect()
request.addfinalizer(teardown)
return client
def skip_if_server_version_lt(min_version):
check = StrictVersion(get_version()) < StrictVersion(min_version)
return pytest.mark.skipif(check, reason="")
@pytest.fixture()
def r(request, **kwargs):
return _get_client(redis.Redis, request, **kwargs)
@pytest.fixture()
def sr(request, **kwargs):
return _get_client(redis.StrictRedis, request, **kwargs)

File diff suppressed because it is too large

View File

@ -0,0 +1,402 @@
from __future__ import with_statement
import os
import pytest
import redis
import time
import re
from threading import Thread
from redis.connection import ssl_available
from .conftest import skip_if_server_version_lt
class DummyConnection(object):
description_format = "DummyConnection<>"
def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()
class TestConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=None,
connection_class=DummyConnection):
connection_kwargs = connection_kwargs or {}
pool = redis.ConnectionPool(
connection_class=connection_class,
max_connections=max_connections,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_max_connections(self):
pool = self.get_pool(max_connections=2)
pool.get_connection('_')
pool.get_connection('_')
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.Connection)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected
class TestBlockingConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
connection_kwargs = connection_kwargs or {}
pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
max_connections=max_connections,
timeout=timeout,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_connection_pool_blocks_until_timeout(self):
"When out of connections, block for timeout seconds, then raise"
pool = self.get_pool(max_connections=1, timeout=0.1)
pool.get_connection('_')
start = time.time()
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
# we should have waited at least 0.1 seconds
assert time.time() - start >= 0.1
def connection_pool_blocks_until_another_connection_released(self):
"""
When out of connections, block until another connection is released
to the pool
"""
pool = self.get_pool(max_connections=1, timeout=2)
c1 = pool.get_connection('_')
def target():
time.sleep(0.1)
pool.release(c1)
Thread(target=target).start()
start = time.time()
pool.get_connection('_')
assert time.time() - start >= 0.1
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
pool = redis.ConnectionPool(
connection_class=redis.UnixDomainSocketConnection,
path='abc',
db=0,
)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
assert repr(pool) == expected
class TestConnectionPoolURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('redis://localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_hostname(self):
pool = redis.ConnectionPool.from_url('redis://myhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_port(self):
pool = redis.ConnectionPool.from_url('redis://localhost:6380')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6380,
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('redis://localhost', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 1,
'password': None,
}
def test_db_in_path(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 2,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3',
db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 3,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
def test_calling_from_subclass_returns_correct_instance(self):
pool = redis.BlockingConnectionPool.from_url('redis://localhost')
assert isinstance(pool, redis.BlockingConnectionPool)
def test_client_creates_connection_pool(self):
r = redis.StrictRedis.from_url('redis://myhost')
assert r.connection_pool.connection_class == redis.Connection
assert r.connection_pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
class TestConnectionPoolUnixSocketURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('unix:///socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('unix:///socket', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 1,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 2,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
class TestSSLConnectionURLParsing(object):
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_defaults(self):
pool = redis.ConnectionPool.from_url('rediss://localhost')
assert pool.connection_class == redis.SSLConnection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_cert_reqs_options(self):
import ssl
pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=optional')
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=required')
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
class TestConnection(object):
def test_on_connect_error(self):
"""
An error in Connection.on_connect should disconnect from the server
see for details: https://github.com/andymccurdy/redis-py/issues/368
"""
# this assumes the Redis server being tested against doesn't have
# 9999 databases ;)
bad_connection = redis.Redis(db=9999)
# an error should be raised on connect
with pytest.raises(redis.RedisError):
bad_connection.info()
pool = bad_connection.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_disconnects_socket(self, r):
"""
If Redis raises a LOADING error, the connection should be
disconnected and a BusyLoadingError raised
"""
with pytest.raises(redis.BusyLoadingError):
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
pool = r.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline_immediate_command(self, r):
"""
BusyLoadingErrors should raise from Pipelines that execute a
command immediately, like WATCH does.
"""
pipe = r.pipeline()
with pytest.raises(redis.BusyLoadingError):
pipe.immediate_execute_command('DEBUG', 'ERROR',
'LOADING fake message')
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline(self, r):
"""
BusyLoadingErrors should be raised from a pipeline execution
regardless of the raise_on_error flag.
"""
pipe = r.pipeline()
pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
with pytest.raises(redis.BusyLoadingError):
pipe.execute()
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_read_only_error(self, r):
"READONLY errors get turned in ReadOnlyError exceptions"
with pytest.raises(redis.ReadOnlyError):
r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah')
def test_connect_from_url_tcp(self):
connection = redis.Redis.from_url('redis://localhost')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'Connection',
'host=localhost,port=6379,db=0',
)
def test_connect_from_url_unix(self):
connection = redis.Redis.from_url('unix:///path/to/socket')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'UnixDomainSocketConnection',
'path=/path/to/socket,db=0',
)

View File

@ -0,0 +1,33 @@
from __future__ import with_statement
import pytest
from redis._compat import unichr, u, unicode
from .conftest import r as _redis_client
class TestEncoding(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_simple_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
r['unicode-string'] = unicode_string
cached_val = r['unicode-string']
assert isinstance(cached_val, unicode)
assert unicode_string == cached_val
def test_list_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
result = [unicode_string, unicode_string, unicode_string]
r.rpush('a', *result)
assert r.lrange('a', 0, -1) == result
class TestCommandsAndTokensArentEncoded(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, charset='utf-16')
def test_basic_command(self, r):
r.set('hello', 'world')

View File

@ -0,0 +1,167 @@
from __future__ import with_statement
import pytest
import time
from redis.exceptions import LockError, ResponseError
from redis.lock import Lock, LuaLock
class TestLock(object):
lock_class = Lock
def get_lock(self, redis, *args, **kwargs):
kwargs['lock_class'] = self.lock_class
return redis.lock(*args, **kwargs)
def test_lock(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
assert sr.get('foo') == lock.local.token
assert sr.ttl('foo') == -1
lock.release()
assert sr.get('foo') is None
def test_competing_locks(self, sr):
lock1 = self.get_lock(sr, 'foo')
lock2 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
assert lock2.acquire(blocking=False)
assert not lock1.acquire(blocking=False)
lock2.release()
def test_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8 < sr.ttl('foo') <= 10
lock.release()
def test_float_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=9.5)
assert lock.acquire(blocking=False)
assert 8 < sr.pttl('foo') <= 9500
lock.release()
def test_blocking_timeout(self, sr):
lock1 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
start = time.time()
assert not lock2.acquire()
assert (time.time() - start) > 0.2
lock1.release()
def test_context_manager(self, sr):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
assert sr.get('foo') == lock.local.token
assert sr.get('foo') is None
def test_high_sleep_raises_error(self, sr):
"If sleep is higher than timeout, it should raise an error"
with pytest.raises(LockError):
self.get_lock(sr, 'foo', timeout=1, sleep=2)
def test_releasing_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
with pytest.raises(LockError):
lock.release()
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
lock.acquire(blocking=False)
# manually change the token
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
def test_extend_lock(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extend_lock_float(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10.0)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10.0)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extending_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
with pytest.raises(LockError):
lock.extend(10)
def test_extending_lock_with_no_timeout_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
with pytest.raises(LockError):
lock.extend(10)
lock.release()
def test_extending_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.extend(10)
class TestLuaLock(TestLock):
lock_class = LuaLock
class TestLockClassSelection(object):
def test_lock_class_argument(self, sr):
lock = sr.lock('foo', lock_class=Lock)
assert type(lock) == Lock
lock = sr.lock('foo', lock_class=LuaLock)
assert type(lock) == LuaLock
def test_cached_lualock_flag(self, sr):
try:
sr._use_lua_lock = True
lock = sr.lock('foo')
assert type(lock) == LuaLock
finally:
sr._use_lua_lock = None
def test_cached_lock_flag(self, sr):
try:
sr._use_lua_lock = False
lock = sr.lock('foo')
assert type(lock) == Lock
finally:
sr._use_lua_lock = None
def test_lua_compatible_server(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
return
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == LuaLock
assert sr._use_lua_lock is True
finally:
sr._use_lua_lock = None
def test_lua_unavailable(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
raise ResponseError()
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == Lock
assert sr._use_lua_lock is False
finally:
sr._use_lua_lock = None

View File

@ -0,0 +1,226 @@
from __future__ import with_statement
import pytest
import redis
from redis._compat import b, u, unichr, unicode
class TestPipeline(object):
def test_pipeline(self, r):
with r.pipeline() as pipe:
pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4)
pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True)
assert pipe.execute() == \
[
True,
b('a1'),
True,
True,
2.0,
[(b('z1'), 2.0), (b('z2'), 4)],
]
def test_pipeline_length(self, r):
with r.pipeline() as pipe:
# Initially empty.
assert len(pipe) == 0
assert not pipe
# Fill 'er up!
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert len(pipe) == 3
assert pipe
# Execute calls reset(), so empty once again.
pipe.execute()
assert len(pipe) == 0
assert not pipe
def test_pipeline_no_transaction(self, r):
with r.pipeline(transaction=False) as pipe:
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert pipe.execute() == [True, True, True]
assert r['a'] == b('a1')
assert r['b'] == b('b1')
assert r['c'] == b('c1')
def test_pipeline_no_transaction_watch(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
pipe.multi()
pipe.set('a', int(a) + 1)
assert pipe.execute() == [True]
def test_pipeline_no_transaction_watch_failure(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
r['a'] = 'bad'
pipe.multi()
pipe.set('a', int(a) + 1)
with pytest.raises(redis.WatchError):
pipe.execute()
assert r['a'] == b('bad')
def test_exec_error_in_response(self, r):
"""
an invalid pipeline command at exec time adds the exception instance
to the list of returned values
"""
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
result = pipe.execute(raise_on_error=False)
assert result[0]
assert r['a'] == b('1')
assert result[1]
assert r['b'] == b('2')
# we can't lpush to a key that's a string value, so this should
# be a ResponseError exception
assert isinstance(result[2], redis.ResponseError)
assert r['c'] == b('a')
# since this isn't a transaction, the other commands after the
# error are still executed
assert result[3]
assert r['d'] == b('4')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_exec_error_raised(self, r):
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_parse_error_raised(self, r):
with r.pipeline() as pipe:
# the zrem is invalid because we don't pass any keys to it
pipe.set('a', 1).zrem('b').set('b', 2)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 2 (ZREM b) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_watch_succeed(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
assert pipe.watching
a_value = pipe.get('a')
b_value = pipe.get('b')
assert a_value == b('1')
assert b_value == b('2')
pipe.multi()
pipe.set('c', 3)
assert pipe.execute() == [True]
assert not pipe.watching
def test_watch_failure(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.multi()
pipe.get('a')
with pytest.raises(redis.WatchError):
pipe.execute()
assert not pipe.watching
def test_unwatch(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.unwatch()
assert not pipe.watching
pipe.get('a')
assert pipe.execute() == [b('1')]
def test_transaction_callable(self, r):
r['a'] = 1
r['b'] = 2
has_run = []
def my_transaction(pipe):
a_value = pipe.get('a')
assert a_value in (b('1'), b('2'))
b_value = pipe.get('b')
assert b_value == b('2')
# silly run-once code... incr's "a" so WatchError should be raised
# forcing this all to run again. this should incr "a" once to "2"
if not has_run:
r.incr('a')
has_run.append('it has')
pipe.multi()
pipe.set('c', int(a_value) + int(b_value))
result = r.transaction(my_transaction, 'a', 'b')
assert result == [True]
assert r['c'] == b('4')
def test_exec_error_in_no_transaction_pipeline(self, r):
r['a'] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen('a')
pipe.expire('a', 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 1 (LLEN a) of '
'pipeline caused error: ')
assert r['a'] == b('1')
def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
key = unichr(3456) + u('abcd') + unichr(3421)
r[key] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen(key)
pipe.expire(key, 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
expected = unicode('Command # 1 (LLEN %s) of pipeline caused '
'error: ') % key
assert unicode(ex.value).startswith(expected)
assert r[key] == b('1')

View File

@ -0,0 +1,392 @@
from __future__ import with_statement
import pytest
import time
import redis
from redis.exceptions import ConnectionError
from redis._compat import basestring, u, unichr
from .conftest import r as _redis_client
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
now = time.time()
timeout = now + timeout
while now < timeout:
message = pubsub.get_message(
ignore_subscribe_messages=ignore_subscribe_messages)
if message is not None:
return message
time.sleep(0.01)
now = time.time()
return None
def make_message(type, channel, data, pattern=None):
return {
'type': type,
'pattern': pattern and pattern.encode('utf-8') or None,
'channel': channel.encode('utf-8'),
'data': data.encode('utf-8') if isinstance(data, basestring) else data
}
def make_subscribe_test_data(pubsub, type):
if type == 'channel':
return {
'p': pubsub,
'sub_type': 'subscribe',
'unsub_type': 'unsubscribe',
'sub_func': pubsub.subscribe,
'unsub_func': pubsub.unsubscribe,
'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
}
elif type == 'pattern':
return {
'p': pubsub,
'sub_type': 'psubscribe',
'unsub_type': 'punsubscribe',
'sub_func': pubsub.psubscribe,
'unsub_func': pubsub.punsubscribe,
'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
}
assert False, 'invalid subscribe type: %s' % type
class TestPubSubSubscribeUnsubscribe(object):
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
for key in keys:
assert unsub_func(key) is None
# should be a message for each channel/pattern we just unsubscribed
# from
for i, key in enumerate(keys):
i = len(keys) - 1 - i
assert wait_for_message(p) == make_message(unsub_type, key, i)
def test_channel_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribe_unsubscribe(**kwargs)
def test_pattern_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribe_unsubscribe(**kwargs)
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
sub_func, unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
# manually disconnect
p.connection.disconnect()
# calling get_message again reconnects and resubscribes
# note, we may not re-subscribe to channels in exactly the same order
# so we have to do some extra checks to make sure we got them all
messages = []
for i in range(len(keys)):
messages.append(wait_for_message(p))
unique_channels = set()
assert len(messages) == len(keys)
for i, message in enumerate(messages):
assert message['type'] == sub_type
assert message['data'] == i + 1
assert isinstance(message['channel'], bytes)
channel = message['channel'].decode('utf-8')
unique_channels.add(channel)
assert len(unique_channels) == len(keys)
for channel in unique_channels:
assert channel in keys
def test_resubscribe_to_channels_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_resubscribe_on_reconnection(**kwargs)
def test_resubscribe_to_patterns_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_resubscribe_on_reconnection(**kwargs)
def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
assert p.subscribed is False
sub_func(keys[0])
# we're now subscribed even though we haven't processed the
# reply from the server just yet
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# we're still subscribed
assert p.subscribed is True
# unsubscribe from all channels
unsub_func()
# we're still technically subscribed until we process the
# response messages from the server
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# now we're no longer subscribed as no more messages can be delivered
# to any channels we were listening to
assert p.subscribed is False
# subscribing again flips the flag back
sub_func(keys[0])
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# unsubscribe again
unsub_func()
assert p.subscribed is True
# subscribe to another channel before reading the unsubscribe response
sub_func(keys[1])
assert p.subscribed is True
# read the unsubscribe for key1
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# we're still subscribed to key2, so subscribed should still be True
assert p.subscribed is True
# read the key2 subscribe message
assert wait_for_message(p) == make_message(sub_type, keys[1], 1)
unsub_func()
# haven't read the message yet, so we're still subscribed
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
# now we're finally unsubscribed
assert p.subscribed is False
def test_subscribe_property_with_channels(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribed_property(**kwargs)
def test_subscribe_property_with_patterns(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribed_property(**kwargs)
def test_ignore_all_subscribe_messages(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
assert wait_for_message(p) is None
assert p.subscribed is False
def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
message = wait_for_message(p, ignore_subscribe_messages=True)
assert message is None
assert p.subscribed is False
class TestPubSubMessages(object):
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
def test_published_message_to_channel(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
assert r.publish('foo', 'test message') == 1
message = wait_for_message(p)
assert isinstance(message, dict)
assert message == make_message('message', 'foo', 'test message')
def test_published_message_to_pattern(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
p.psubscribe('f*')
# 1 to pattern, 1 to channel
assert r.publish('foo', 'test message') == 2
message1 = wait_for_message(p)
message2 = wait_for_message(p)
assert isinstance(message1, dict)
assert isinstance(message2, dict)
expected = [
make_message('message', 'foo', 'test message'),
make_message('pmessage', 'foo', 'test message', pattern='f*')
]
assert message1 in expected
assert message2 in expected
assert message1 != message2
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(foo=self.message_handler)
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', 'foo', 'test message')
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{'f*': self.message_handler})
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', 'foo', 'test message',
pattern='f*')
def test_unicode_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
channel = u('uni') + unichr(4456) + u('code')
channels = {channel: self.message_handler}
p.subscribe(**channels)
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', channel, 'test message')
def test_unicode_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
pattern = u('uni') + unichr(4456) + u('*')
channel = u('uni') + unichr(4456) + u('code')
p.psubscribe(**{pattern: self.message_handler})
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', channel,
'test message', pattern=pattern)
class TestPubSubAutoDecoding(object):
"These tests only validate that we get unicode values back"
channel = u('uni') + unichr(4456) + u('code')
pattern = u('uni') + unichr(4456) + u('*')
data = u('abc') + unichr(4458) + u('123')
def make_message(self, type, channel, data, pattern=None):
return {
'type': type,
'channel': channel,
'pattern': pattern,
'data': data
}
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_channel_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.subscribe(self.channel)
assert wait_for_message(p) == self.make_message('subscribe',
self.channel, 1)
p.unsubscribe(self.channel)
assert wait_for_message(p) == self.make_message('unsubscribe',
self.channel, 0)
def test_pattern_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.psubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('psubscribe',
self.pattern, 1)
p.punsubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('punsubscribe',
self.pattern, 0)
def test_channel_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(self.channel)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('message',
self.channel,
self.data)
def test_pattern_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(self.pattern)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('pmessage',
self.channel,
self.data,
pattern=self.pattern)
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(**{self.channel: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
self.data)
# test that we reconnected to the correct channel
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
new_data)
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{self.pattern: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
self.data,
pattern=self.pattern)
# test that we reconnected to the correct pattern
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
new_data,
pattern=self.pattern)
class TestPubSubRedisDown(object):
def test_channel_subscribe(self, r):
r = redis.Redis(host='localhost', port=6390)
p = r.pubsub()
with pytest.raises(ConnectionError):
p.subscribe('foo')

View File

@ -0,0 +1,82 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis._compat import b
multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""
class TestScripting(object):
@pytest.fixture(autouse=True)
def reset_scripts(self, r):
r.script_flush()
def test_eval(self, r):
r.set('a', 2)
# 2 * 3 == 6
assert r.eval(multiply_script, 1, 'a', 3) == 6
def test_evalsha(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# 2 * 3 == 6
assert r.evalsha(sha, 1, 'a', 3) == 6
def test_evalsha_script_not_loaded(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# remove the script from Redis's cache
r.script_flush()
with pytest.raises(exceptions.NoScriptError):
r.evalsha(sha, 1, 'a', 3)
def test_script_loading(self, r):
# get the sha, then clear the cache
sha = r.script_load(multiply_script)
r.script_flush()
assert r.script_exists(sha) == [False]
r.script_load(multiply_script)
assert r.script_exists(sha) == [True]
def test_script_object(self, r):
r.set('a', 2)
multiply = r.register_script(multiply_script)
assert not multiply.sha
# test evalsha fail -> script load + retry
assert multiply(keys=['a'], args=[3]) == 6
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# test first evalsha
assert multiply(keys=['a'], args=[3]) == 6
def test_script_object_in_pipeline(self, r):
multiply = r.register_script(multiply_script)
assert not multiply.sha
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
multiply(keys=['a'], args=[3], client=pipe)
# even though the pipeline wasn't executed yet, we made sure the
# script was loaded and got a valid sha
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# [SET worked, GET 'a', result of multiply script]
assert pipe.execute() == [True, b('2'), 6]
# purge the script from redis's cache and re-run the pipeline
# the multiply script object knows its sha, so it shouldn't get
# reloaded until pipe.execute()
r.script_flush()
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
assert multiply.sha
multiply(keys=['a'], args=[3], client=pipe)
assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiply script]
assert pipe.execute() == [True, b('2'), 6]

View File

@ -0,0 +1,173 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis.sentinel import (Sentinel, SentinelConnectionPool,
MasterNotFoundError, SlaveNotFoundError)
from redis._compat import next
import redis.sentinel
class SentinelTestClient(object):
def __init__(self, cluster, id):
self.cluster = cluster
self.id = id
def sentinel_masters(self):
self.cluster.connection_error_if_down(self)
return {self.cluster.service_name: self.cluster.master}
def sentinel_slaves(self, master_name):
self.cluster.connection_error_if_down(self)
if master_name != self.cluster.service_name:
return []
return self.cluster.slaves
class SentinelTestCluster(object):
def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
self.clients = {}
self.master = {
'ip': ip,
'port': port,
'is_master': True,
'is_sdown': False,
'is_odown': False,
'num-other-sentinels': 0,
}
self.service_name = service_name
self.slaves = []
self.nodes_down = set()
def connection_error_if_down(self, node):
if node.id in self.nodes_down:
raise exceptions.ConnectionError
def client(self, host, port, **kwargs):
return SentinelTestClient(self, (host, port))
@pytest.fixture()
def cluster(request):
def teardown():
redis.sentinel.StrictRedis = saved_StrictRedis
cluster = SentinelTestCluster()
saved_StrictRedis = redis.sentinel.StrictRedis
redis.sentinel.StrictRedis = cluster.client
request.addfinalizer(teardown)
return cluster
@pytest.fixture()
def sentinel(request, cluster):
return Sentinel([('foo', 26379), ('bar', 26379)])
def test_discover_master(sentinel):
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_discover_master_error(sentinel):
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('xxx')
def test_discover_master_sentinel_down(cluster, sentinel):
# Put first sentinel 'foo' down
cluster.nodes_down.add(('foo', 26379))
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
# 'bar' is now first sentinel
assert sentinel.sentinels[0].id == ('bar', 26379)
def test_master_min_other_sentinels(cluster):
sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
# min_other_sentinels
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
cluster.master['num-other-sentinels'] = 2
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_master_odown(cluster, sentinel):
cluster.master['is_odown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_master_sdown(cluster, sentinel):
cluster.master['is_sdown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_discover_slaves(cluster, sentinel):
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves = [
{'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
]
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
# slave0 -> ODOWN
cluster.slaves[0]['is_odown'] = True
assert sentinel.discover_slaves('mymaster') == [
('slave1', 1234)]
# slave1 -> SDOWN
cluster.slaves[1]['is_sdown'] = True
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves[0]['is_odown'] = False
cluster.slaves[1]['is_sdown'] = False
# node0 -> DOWN
cluster.nodes_down.add(('foo', 26379))
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
def test_master_for(cluster, sentinel):
master = sentinel.master_for('mymaster', db=9)
assert master.ping()
assert master.connection_pool.master_address == ('127.0.0.1', 6379)
# Use internal connection check
master = sentinel.master_for('mymaster', db=9, check_connection=True)
assert master.ping()
def test_slave_for(cluster, sentinel):
cluster.slaves = [
{'ip': '127.0.0.1', 'port': 6379,
'is_odown': False, 'is_sdown': False},
]
slave = sentinel.slave_for('mymaster', db=9)
assert slave.ping()
def test_slave_for_slave_not_found_error(cluster, sentinel):
cluster.master['is_odown'] = True
slave = sentinel.slave_for('mymaster', db=9)
with pytest.raises(SlaveNotFoundError):
slave.ping()
def test_slave_round_robin(cluster, sentinel):
cluster.slaves = [
{'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
]
pool = SentinelConnectionPool('mymaster', sentinel)
rotator = pool.rotate_slaves()
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
# Fallback to master
assert next(rotator) == ('127.0.0.1', 6379)
with pytest.raises(SlaveNotFoundError):
next(rotator)

362
client/openresty/ledis.lua Normal file
View File

@ -0,0 +1,362 @@
-- adapted from the openresty redis client library
local sub = string.sub
local byte = string.byte
local tcp = ngx.socket.tcp
local concat = table.concat
local null = ngx.null
local pairs = pairs
local unpack = unpack
local setmetatable = setmetatable
local tonumber = tonumber
local error = error
local ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function (narr, nrec) return {} end
end
local _M = new_tab(0, 155)
_M._VERSION = '0.01'
local commands = {
--[[kv]]
"decr",
"decrby",
"del",
"exists",
"get",
"getset",
"incr",
"incrby",
"mget",
"mset",
"set",
"setnx",
--[[hash]]
"hdel",
"hexists",
"hget",
"hgetall",
"hincrby",
"hkeys",
"hlen",
"hmget",
--[["hmset",]]
"hset",
"hvals",
"hclear",
--[[list]]
"lindex",
"llen",
"lpop",
"lrange",
"lpush",
"rpop",
"rpush",
"lclear",
--[[zset]]
"zadd",
"zcard",
"zcount",
"zincrby",
"zrange",
"zrangebyscore",
"zrank",
"zrem",
"zremrangebyrank",
"zremrangebyscore",
"zrevrange",
"zrevrank",
"zrevrangebyscore",
"zscore",
"zclear",
--[[server]]
"ping",
"echo",
"select"
}
local mt = { __index = _M }
function _M.new(self)
local sock, err = tcp()
if not sock then
return nil, err
end
return setmetatable({ sock = sock }, mt)
end
function _M.set_timeout(self, timeout)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:settimeout(timeout)
end
function _M.connect(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:connect(...)
end
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:setkeepalive(...)
end
function _M.get_reused_times(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:getreusedtimes()
end
local function close(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:close()
end
_M.close = close
local function _read_reply(self, sock)
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
local prefix = byte(line)
if prefix == 36 then -- char '$'
-- print("bulk reply")
local size = tonumber(sub(line, 2))
if size < 0 then
return null
end
local data, err = sock:receive(size)
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
local dummy, err = sock:receive(2) -- ignore CRLF
if not dummy then
return nil, err
end
return data
elseif prefix == 43 then -- char '+'
-- print("status reply")
return sub(line, 2)
elseif prefix == 42 then -- char '*'
local n = tonumber(sub(line, 2))
-- print("multi-bulk reply: ", n)
if n < 0 then
return null
end
local vals = new_tab(n, 0);
local nvals = 0
for i = 1, n do
local res, err = _read_reply(self, sock)
if res then
nvals = nvals + 1
vals[nvals] = res
elseif res == nil then
return nil, err
else
-- be a valid redis error value
nvals = nvals + 1
vals[nvals] = {false, err}
end
end
return vals
elseif prefix == 58 then -- char ':'
-- print("integer reply")
return tonumber(sub(line, 2))
elseif prefix == 45 then -- char '-'
-- print("error reply: ", n)
return false, sub(line, 2)
else
return nil, "unkown prefix: \"" .. prefix .. "\""
end
end
local function _gen_req(args)
local nargs = #args
local req = new_tab(nargs + 1, 0)
req[1] = "*" .. nargs .. "\r\n"
local nbits = 1
for i = 1, nargs do
local arg = args[i]
nbits = nbits + 1
if not arg then
req[nbits] = "$-1\r\n"
else
if type(arg) ~= "string" then
arg = tostring(arg)
end
req[nbits] = "$" .. #arg .. "\r\n" .. arg .. "\r\n"
end
end
-- it is faster to do the string concatenation on the Lua side
return concat(req)
end
local function _do_cmd(self, ...)
local args = {...}
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local req = _gen_req(args)
local reqs = self._reqs
if reqs then
reqs[#reqs + 1] = req
return
end
-- print("request: ", table.concat(req))
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
return _read_reply(self, sock)
end
function _M.read_reply(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local res, err = _read_reply(self, sock)
return res, err
end
for i = 1, #commands do
local cmd = commands[i]
_M[cmd] =
function (self, ...)
return _do_cmd(self, cmd, ...)
end
end
function _M.hmset(self, hashname, ...)
local args = {...}
if #args == 1 then
local t = args[1]
local n = 0
for k, v in pairs(t) do
n = n + 2
end
local array = new_tab(n, 0)
local i = 0
for k, v in pairs(t) do
array[i + 1] = k
array[i + 2] = v
i = i + 2
end
-- print("key", hashname)
return _do_cmd(self, "hmset", hashname, unpack(array))
end
-- backwards compatibility
return _do_cmd(self, "hmset", hashname, ...)
end
function _M.array_to_hash(self, t)
local n = #t
-- print("n = ", n)
local h = new_tab(0, n / 2)
for i = 1, n, 2 do
h[t[i]] = t[i + 1]
end
return h
end
function _M.add_commands(...)
local cmds = {...}
for i = 1, #cmds do
local cmd = cmds[i]
_M[cmd] =
function (self, ...)
return _do_cmd(self, cmd, ...)
end
end
end
return _M

View File

@ -3,7 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/garyburd/redigo/redis" "github.com/siddontang/ledisdb/client/go/redis"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"

View File

@ -6,7 +6,8 @@
"compression": false, "compression": false,
"block_size": 32768, "block_size": 32768,
"write_buffer_size": 67108864, "write_buffer_size": 67108864,
"cache_size": 524288000 "cache_size": 524288000,
"max_open_files":1024
} }
}, },

View File

@ -5,7 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/log"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"

View File

@ -4,7 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"io" "io"
"os" "os"
) )
@ -73,7 +73,9 @@ func (l *Ledis) Dump(w io.Writer) error {
return err return err
} }
it := sp.Iterator(nil, nil, leveldb.RangeClose, 0, -1) it := sp.NewIterator()
it.SeekToFirst()
var key []byte var key []byte
var value []byte var value []byte
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {

View File

@ -2,7 +2,7 @@ package ledis
import ( import (
"bytes" "bytes"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"os" "os"
"testing" "testing"
) )
@ -59,7 +59,7 @@ func TestDump(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
key := it.Key() key := it.Key()
value := it.Value() value := it.Value()

View File

@ -3,8 +3,8 @@ package ledis
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/log"
"path" "path"
"sync" "sync"
"time" "time"
@ -46,6 +46,7 @@ type Ledis struct {
binlog *BinLog binlog *BinLog
quit chan struct{} quit chan struct{}
jobs *sync.WaitGroup
} }
func Open(configJson json.RawMessage) (*Ledis, error) { func Open(configJson json.RawMessage) (*Ledis, error) {
@ -75,6 +76,7 @@ func OpenWithConfig(cfg *Config) (*Ledis, error) {
l := new(Ledis) l := new(Ledis)
l.quit = make(chan struct{}) l.quit = make(chan struct{})
l.jobs = new(sync.WaitGroup)
l.ldb = ldb l.ldb = ldb
@ -118,6 +120,7 @@ func newDB(l *Ledis, index uint8) *DB {
func (l *Ledis) Close() { func (l *Ledis) Close() {
close(l.quit) close(l.quit)
l.jobs.Wait()
l.ldb.Close() l.ldb.Close()
@ -156,19 +159,23 @@ func (l *Ledis) activeExpireCycle() {
executors[i] = db.newEliminator() executors[i] = db.newEliminator()
} }
l.jobs.Add(1)
go func() { go func() {
tick := time.NewTicker(1 * time.Second) tick := time.NewTicker(1 * time.Second)
for { end := false
for !end {
select { select {
case <-tick.C: case <-tick.C:
for _, eli := range executors { for _, eli := range executors {
eli.active() eli.active()
} }
case <-l.quit: case <-l.quit:
end = true
break break
} }
} }
tick.Stop() tick.Stop()
l.jobs.Done()
}() }()
} }
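The jobs wait group and quit channel added above let Close() block until the background expire goroutine has fully stopped. A minimal, self-contained sketch of the same shutdown pattern (the worker type and all names below are illustrative, not part of ledisdb):

package main

import (
	"fmt"
	"sync"
	"time"
)

// worker ticks once per second until quit is closed, then signals Done,
// mirroring the quit-channel plus WaitGroup wiring in activeExpireCycle above.
type worker struct {
	quit chan struct{}
	jobs *sync.WaitGroup
}

func (w *worker) run() {
	w.jobs.Add(1)
	go func() {
		defer w.jobs.Done()
		tick := time.NewTicker(1 * time.Second)
		defer tick.Stop()
		for {
			select {
			case <-tick.C:
				fmt.Println("active expire pass")
			case <-w.quit:
				return
			}
		}
	}()
}

func (w *worker) close() {
	close(w.quit)
	w.jobs.Wait() // blocks until the goroutine above has returned
}

func main() {
	w := &worker{quit: make(chan struct{}), jobs: new(sync.WaitGroup)}
	w.run()
	time.Sleep(2500 * time.Millisecond)
	w.close()
}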

View File

@ -5,7 +5,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/log"
"io" "io"
"os" "os"
) )

View File

@ -3,14 +3,14 @@ package ledis
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"os" "os"
"path" "path"
"testing" "testing"
) )
func checkLedisEqual(master *Ledis, slave *Ledis) error { func checkLedisEqual(master *Ledis, slave *Ledis) error {
it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
key := it.Key() key := it.Key()
value := it.Value() value := it.Value()

View File

@ -3,7 +3,7 @@ package ledis
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"time" "time"
) )
@ -132,7 +132,7 @@ func (db *DB) hDelete(t *tx, key []byte) int64 {
stop := db.hEncodeStopKey(key) stop := db.hEncodeStopKey(key)
var num int64 = 0 var num int64 = 0
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
num++ num++
@ -232,10 +232,11 @@ func (db *DB) HMset(key []byte, args ...FVPair) error {
return err return err
} }
func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) { func (db *DB) HMget(key []byte, args ...[]byte) ([]interface{}, error) {
var ek []byte var ek []byte
var v []byte
var err error it := db.db.NewIterator()
defer it.Close()
r := make([]interface{}, len(args)) r := make([]interface{}, len(args))
for i := 0; i < len(args); i++ { for i := 0; i < len(args); i++ {
@ -245,17 +246,13 @@ func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) {
ek = db.hEncodeHashKey(key, args[i]) ek = db.hEncodeHashKey(key, args[i])
if v, err = db.db.Get(ek); err != nil { r[i] = it.Find(ek)
return nil, err
}
r[i] = v
} }
return r, nil return r, nil
} }
func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { func (db *DB) HDel(key []byte, args ...[]byte) (int64, error) {
t := db.hashTx t := db.hashTx
var ek []byte var ek []byte
@ -265,6 +262,9 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
it := db.db.NewIterator()
defer it.Close()
var num int64 = 0 var num int64 = 0
for i := 0; i < len(args); i++ { for i := 0; i < len(args); i++ {
if err := checkHashKFSize(key, args[i]); err != nil { if err := checkHashKFSize(key, args[i]); err != nil {
@ -273,9 +273,8 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) {
ek = db.hEncodeHashKey(key, args[i]) ek = db.hEncodeHashKey(key, args[i])
if v, err = db.db.Get(ek); err != nil { v = it.Find(ek)
return 0, err if v == nil {
} else if v == nil {
continue continue
} else { } else {
num++ num++
@ -355,7 +354,7 @@ func (db *DB) HGetAll(key []byte) ([]interface{}, error) {
v := make([]interface{}, 0, 16) v := make([]interface{}, 0, 16)
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
_, k, err := db.hDecodeHashKey(it.Key()) _, k, err := db.hDecodeHashKey(it.Key())
if err != nil { if err != nil {
@ -380,7 +379,7 @@ func (db *DB) HKeys(key []byte) ([]interface{}, error) {
v := make([]interface{}, 0, 16) v := make([]interface{}, 0, 16)
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
_, k, err := db.hDecodeHashKey(it.Key()) _, k, err := db.hDecodeHashKey(it.Key())
if err != nil { if err != nil {
@ -404,7 +403,7 @@ func (db *DB) HValues(key []byte) ([]interface{}, error) {
v := make([]interface{}, 0, 16) v := make([]interface{}, 0, 16)
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
v = append(v, it.Value()) v = append(v, it.Value())
} }
@ -443,7 +442,7 @@ func (db *DB) hFlush() (drop int64, err error) {
maxKey[0] = db.index maxKey[0] = db.index
maxKey[1] = hSizeType + 1 maxKey[1] = hSizeType + 1
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
drop++ drop++
@ -485,7 +484,7 @@ func (db *DB) HScan(key []byte, field []byte, count int, inclusive bool) ([]FVPa
rangeType = leveldb.RangeOpen rangeType = leveldb.RangeOpen
} }
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
if _, f, err := db.hDecodeHashKey(it.Key()); err != nil { if _, f, err := db.hDecodeHashKey(it.Key()); err != nil {
continue continue

View File

@ -32,11 +32,28 @@ func TestDBHash(t *testing.T) {
key := []byte("testdb_hash_a") key := []byte("testdb_hash_a")
if n, err := db.HSet(key, []byte("a"), []byte("hello world")); err != nil { if n, err := db.HSet(key, []byte("a"), []byte("hello world 1")); err != nil {
t.Fatal(err) t.Fatal(err)
} else if n != 1 { } else if n != 1 {
t.Fatal(n) t.Fatal(n)
} }
if n, err := db.HSet(key, []byte("b"), []byte("hello world 2")); err != nil {
t.Fatal(err)
} else if n != 1 {
t.Fatal(n)
}
ay, _ := db.HMget(key, []byte("a"), []byte("b"))
if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" {
t.Fatal(string(v1))
}
if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" {
t.Fatal(string(v2))
}
} }
func TestDBHScan(t *testing.T) { func TestDBHScan(t *testing.T) {

View File

@ -2,7 +2,7 @@ package ledis
import ( import (
"errors" "errors"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"time" "time"
) )
@ -204,18 +204,15 @@ func (db *DB) IncryBy(key []byte, increment int64) (int64, error) {
func (db *DB) MGet(keys ...[]byte) ([]interface{}, error) { func (db *DB) MGet(keys ...[]byte) ([]interface{}, error) {
values := make([]interface{}, len(keys)) values := make([]interface{}, len(keys))
var err error it := db.db.NewIterator()
var value []byte defer it.Close()
for i := range keys { for i := range keys {
if err := checkKeySize(keys[i]); err != nil { if err := checkKeySize(keys[i]); err != nil {
return nil, err return nil, err
} }
if value, err = db.db.Get(db.encodeKVKey(keys[i])); err != nil { values[i] = it.Find(db.encodeKVKey(keys[i]))
return nil, err
}
values[i] = value
} }
return values, nil return values, nil
@ -319,7 +316,7 @@ func (db *DB) flush() (drop int64, err error) {
minKey := db.encodeKVMinKey() minKey := db.encodeKVMinKey()
maxKey := db.encodeKVMaxKey() maxKey := db.encodeKVMaxKey()
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
drop++ drop++
@ -362,7 +359,7 @@ func (db *DB) Scan(key []byte, count int, inclusive bool) ([]KVPair, error) {
rangeType = leveldb.RangeOpen rangeType = leveldb.RangeOpen
} }
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
if key, err := db.decodeKVKey(it.Key()); err != nil { if key, err := db.decodeKVKey(it.Key()); err != nil {
continue continue

View File

@ -19,11 +19,28 @@ func TestKVCodec(t *testing.T) {
func TestDBKV(t *testing.T) { func TestDBKV(t *testing.T) {
db := getTestDB() db := getTestDB()
key := []byte("testdb_kv_a") key1 := []byte("testdb_kv_a")
if err := db.Set(key, []byte("hello world")); err != nil { if err := db.Set(key1, []byte("hello world 1")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
key2 := []byte("testdb_kv_b")
if err := db.Set(key2, []byte("hello world 2")); err != nil {
t.Fatal(err)
}
ay, _ := db.MGet(key1, key2)
if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" {
t.Fatal(string(v1))
}
if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" {
t.Fatal(string(v2))
}
} }
func TestDBScan(t *testing.T) { func TestDBScan(t *testing.T) {

View File

@ -3,7 +3,7 @@ package ledis
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"time" "time"
) )
@ -200,7 +200,7 @@ func (db *DB) lDelete(t *tx, key []byte) int64 {
startKey := db.lEncodeListKey(key, headSeq) startKey := db.lEncodeListKey(key, headSeq)
stopKey := db.lEncodeListKey(key, tailSeq) stopKey := db.lEncodeListKey(key, tailSeq)
it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
num++ num++
@ -361,7 +361,7 @@ func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error)
startKey := db.lEncodeListKey(key, startSeq) startKey := db.lEncodeListKey(key, startSeq)
stopKey := db.lEncodeListKey(key, stopSeq) stopKey := db.lEncodeListKey(key, stopSeq)
it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
v = append(v, it.Value()) v = append(v, it.Value())
} }
@ -408,7 +408,7 @@ func (db *DB) lFlush() (drop int64, err error) {
maxKey[0] = db.index maxKey[0] = db.index
maxKey[1] = lMetaType + 1 maxKey[1] = lMetaType + 1
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
drop++ drop++

View File

@ -3,7 +3,7 @@ package ledis
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"time" "time"
) )
@ -119,7 +119,7 @@ func (db *DB) expFlush(t *tx, expType byte) (err error) {
maxKey[0] = db.index maxKey[0] = db.index
maxKey[1] = expMetaType + 1 maxKey[1] = expMetaType + 1
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
drop++ drop++
@ -173,7 +173,7 @@ func (eli *elimination) active() {
continue continue
} }
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for it.Valid() { for it.Valid() {
for i := 1; i < 512 && it.Valid(); i++ { for i := 1; i < 512 && it.Valid(); i++ {
expKeys = append(expKeys, it.Key(), it.Value()) expKeys = append(expKeys, it.Key(), it.Value())

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"time" "time"
) )
@ -434,7 +434,7 @@ func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) {
rangeType := leveldb.RangeROpen rangeType := leveldb.RangeROpen
it := db.db.Iterator(minKey, maxKey, rangeType, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, -1)
var n int64 = 0 var n int64 = 0
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
n++ n++
@ -459,16 +459,16 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) {
if s, err := Int64(v, err); err != nil { if s, err := Int64(v, err); err != nil {
return 0, err return 0, err
} else { } else {
var it *leveldb.Iterator var it *leveldb.RangeLimitIterator
sk := db.zEncodeScoreKey(key, member, s) sk := db.zEncodeScoreKey(key, member, s)
if !reverse { if !reverse {
minKey := db.zEncodeStartScoreKey(key, MinScore) minKey := db.zEncodeStartScoreKey(key, MinScore)
it = db.db.Iterator(minKey, sk, leveldb.RangeClose, 0, -1) it = db.db.RangeLimitIterator(minKey, sk, leveldb.RangeClose, 0, -1)
} else { } else {
maxKey := db.zEncodeStopScoreKey(key, MaxScore) maxKey := db.zEncodeStopScoreKey(key, MaxScore)
it = db.db.RevIterator(sk, maxKey, leveldb.RangeClose, 0, -1) it = db.db.RevRangeLimitIterator(sk, maxKey, leveldb.RangeClose, 0, -1)
} }
var lastKey []byte = nil var lastKey []byte = nil
@ -492,14 +492,14 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) {
return -1, nil return -1, nil
} }
func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.Iterator { func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.RangeLimitIterator {
minKey := db.zEncodeStartScoreKey(key, min) minKey := db.zEncodeStartScoreKey(key, min)
maxKey := db.zEncodeStopScoreKey(key, max) maxKey := db.zEncodeStopScoreKey(key, max)
if !reverse { if !reverse {
return db.db.Iterator(minKey, maxKey, leveldb.RangeClose, offset, limit) return db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
} else { } else {
return db.db.RevIterator(minKey, maxKey, leveldb.RangeClose, offset, limit) return db.db.RevRangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
} }
} }
@ -567,7 +567,7 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i
} }
v := make([]interface{}, 0, nv) v := make([]interface{}, 0, nv)
var it *leveldb.Iterator var it *leveldb.RangeLimitIterator
//if reverse and offset is 0, limit < 0, we may use forward iterator then reverse //if reverse and offset is 0, limit < 0, we may use forward iterator then reverse
//because leveldb iterator prev is slower than next //because leveldb iterator prev is slower than next
@ -745,7 +745,7 @@ func (db *DB) zFlush() (drop int64, err error) {
maxKey[0] = db.index maxKey[0] = db.index
maxKey[1] = zScoreType + 1 maxKey[1] = zScoreType + 1
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
t.Delete(it.Key()) t.Delete(it.Key())
drop++ drop++
@ -788,7 +788,7 @@ func (db *DB) ZScan(key []byte, member []byte, count int, inclusive bool) ([]Sco
rangeType = leveldb.RangeOpen rangeType = leveldb.RangeOpen
} }
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
for ; it.Valid(); it.Next() { for ; it.Valid(); it.Next() {
if _, m, err := db.zDecodeSetKey(it.Key()); err != nil { if _, m, err := db.zDecodeSetKey(it.Key()); err != nil {
continue continue

View File

@ -1,7 +1,7 @@
package ledis package ledis
import ( import (
"github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/ledisdb/leveldb"
"sync" "sync"
) )

59
leveldb/batch.go Normal file
View File

@ -0,0 +1,59 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include "leveldb/c.h"
import "C"
import (
"unsafe"
)
type WriteBatch struct {
db *DB
wbatch *C.leveldb_writebatch_t
}
func (w *WriteBatch) Close() {
C.leveldb_writebatch_destroy(w.wbatch)
}
func (w *WriteBatch) Put(key, value []byte) {
var k, v *C.char
if len(key) != 0 {
k = (*C.char)(unsafe.Pointer(&key[0]))
}
if len(value) != 0 {
v = (*C.char)(unsafe.Pointer(&value[0]))
}
lenk := len(key)
lenv := len(value)
C.leveldb_writebatch_put(w.wbatch, k, C.size_t(lenk), v, C.size_t(lenv))
}
func (w *WriteBatch) Delete(key []byte) {
C.leveldb_writebatch_delete(w.wbatch,
(*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key)))
}
func (w *WriteBatch) Commit() error {
return w.commit(w.db.writeOpts)
}
func (w *WriteBatch) SyncCommit() error {
return w.commit(w.db.syncWriteOpts)
}
func (w *WriteBatch) Rollback() {
C.leveldb_writebatch_clear(w.wbatch)
}
func (w *WriteBatch) commit(wb *WriteOptions) error {
var errStr *C.char
C.leveldb_write(w.db.db, wb.Opt, w.wbatch, &errStr)
if errStr != nil {
return saveError(errStr)
}
return nil
}
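A short usage sketch for the WriteBatch above: queue several mutations, apply them in a single write with Commit, or throw them away with Rollback. The path is illustrative; Config and OpenWithConfig are the ones defined in leveldb/db.go further down.

package main

import (
	"github.com/siddontang/ledisdb/leveldb"
)

func main() {
	db, err := leveldb.OpenWithConfig(&leveldb.Config{Path: "/tmp/wb_example"})
	if err != nil {
		panic(err)
	}
	defer db.Close()

	wb := db.NewWriteBatch()
	defer wb.Close()

	// queue mutations, then apply them in one write
	wb.Put([]byte("a"), []byte("1"))
	wb.Put([]byte("b"), []byte("2"))
	wb.Delete([]byte("stale"))
	if err := wb.Commit(); err != nil {
		panic(err)
	}

	// a queued mutation can also be discarded without applying it
	wb.Put([]byte("c"), []byte("3"))
	wb.Rollback()
}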

18
leveldb/cache.go Normal file
View File

@ -0,0 +1,18 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include <stdint.h>
// #include "leveldb/c.h"
import "C"
type Cache struct {
Cache *C.leveldb_cache_t
}
func NewLRUCache(capacity int) *Cache {
return &Cache{C.leveldb_cache_create_lru(C.size_t(capacity))}
}
func (c *Cache) Close() {
C.leveldb_cache_destroy(c.Cache)
}

328
leveldb/db.go Normal file
View File

@ -0,0 +1,328 @@
package leveldb
/*
#cgo LDFLAGS: -lleveldb
#include <leveldb/c.h>
*/
import "C"
import (
"encoding/json"
"os"
"unsafe"
)
const defaultFilterBits int = 10
type Config struct {
Path string `json:"path"`
Compression bool `json:"compression"`
BlockSize int `json:"block_size"`
WriteBufferSize int `json:"write_buffer_size"`
CacheSize int `json:"cache_size"`
MaxOpenFiles int `json:"max_open_files"`
}
type DB struct {
cfg *Config
db *C.leveldb_t
opts *Options
//for default read and write options
readOpts *ReadOptions
writeOpts *WriteOptions
iteratorOpts *ReadOptions
syncWriteOpts *WriteOptions
cache *Cache
filter *FilterPolicy
}
func Open(configJson json.RawMessage) (*DB, error) {
cfg := new(Config)
err := json.Unmarshal(configJson, cfg)
if err != nil {
return nil, err
}
return OpenWithConfig(cfg)
}
func OpenWithConfig(cfg *Config) (*DB, error) {
if err := os.MkdirAll(cfg.Path, os.ModePerm); err != nil {
return nil, err
}
db := new(DB)
db.cfg = cfg
if err := db.open(); err != nil {
return nil, err
}
return db, nil
}
func (db *DB) open() error {
db.opts = db.initOptions(db.cfg)
db.readOpts = NewReadOptions()
db.writeOpts = NewWriteOptions()
db.iteratorOpts = NewReadOptions()
db.iteratorOpts.SetFillCache(false)
db.syncWriteOpts = NewWriteOptions()
db.syncWriteOpts.SetSync(true)
var errStr *C.char
ldbname := C.CString(db.cfg.Path)
defer C.leveldb_free(unsafe.Pointer(ldbname))
db.db = C.leveldb_open(db.opts.Opt, ldbname, &errStr)
if errStr != nil {
return saveError(errStr)
}
return nil
}
func (db *DB) initOptions(cfg *Config) *Options {
opts := NewOptions()
opts.SetCreateIfMissing(true)
if cfg.CacheSize <= 0 {
cfg.CacheSize = 4 * 1024 * 1024
}
db.cache = NewLRUCache(cfg.CacheSize)
opts.SetCache(db.cache)
//always use a bloom filter to cut disk reads for missing keys
db.filter = NewBloomFilter(defaultFilterBits)
opts.SetFilterPolicy(db.filter)
if !cfg.Compression {
opts.SetCompression(NoCompression)
}
if cfg.BlockSize <= 0 {
cfg.BlockSize = 4 * 1024
}
opts.SetBlockSize(cfg.BlockSize)
if cfg.WriteBufferSize <= 0 {
cfg.WriteBufferSize = 4 * 1024 * 1024
}
opts.SetWriteBufferSize(cfg.WriteBufferSize)
if cfg.MaxOpenFiles < 1024 {
cfg.MaxOpenFiles = 1024
}
opts.SetMaxOpenFiles(cfg.MaxOpenFiles)
return opts
}
func (db *DB) Close() {
C.leveldb_close(db.db)
db.db = nil
db.opts.Close()
if db.cache != nil {
db.cache.Close()
}
if db.filter != nil {
db.filter.Close()
}
db.readOpts.Close()
db.writeOpts.Close()
db.iteratorOpts.Close()
db.syncWriteOpts.Close()
}
func (db *DB) Destroy() error {
path := db.cfg.Path
db.Close()
opts := NewOptions()
defer opts.Close()
var errStr *C.char
ldbname := C.CString(path)
defer C.leveldb_free(unsafe.Pointer(ldbname))
C.leveldb_destroy_db(opts.Opt, ldbname, &errStr)
if errStr != nil {
return saveError(errStr)
}
return nil
}
func (db *DB) Clear() error {
bc := db.NewWriteBatch()
defer bc.Close()
var err error
it := db.NewIterator()
it.SeekToFirst()
num := 0
for ; it.Valid(); it.Next() {
bc.Delete(it.Key())
num++
if num == 1000 {
num = 0
if err = bc.Commit(); err != nil {
return err
}
}
}
err = bc.Commit()
return err
}
func (db *DB) Put(key, value []byte) error {
return db.put(db.writeOpts, key, value)
}
func (db *DB) SyncPut(key, value []byte) error {
return db.put(db.syncWriteOpts, key, value)
}
func (db *DB) Get(key []byte) ([]byte, error) {
return db.get(db.readOpts, key)
}
func (db *DB) Delete(key []byte) error {
return db.delete(db.writeOpts, key)
}
func (db *DB) SyncDelete(key []byte) error {
return db.delete(db.syncWriteOpts, key)
}
func (db *DB) NewWriteBatch() *WriteBatch {
wb := &WriteBatch{
db: db,
wbatch: C.leveldb_writebatch_create(),
}
return wb
}
func (db *DB) NewSnapshot() *Snapshot {
s := &Snapshot{
db: db,
snap: C.leveldb_create_snapshot(db.db),
readOpts: NewReadOptions(),
iteratorOpts: NewReadOptions(),
}
s.readOpts.SetSnapshot(s)
s.iteratorOpts.SetSnapshot(s)
s.iteratorOpts.SetFillCache(false)
return s
}
func (db *DB) NewIterator() *Iterator {
it := new(Iterator)
it.it = C.leveldb_create_iterator(db.db, db.iteratorOpts.Opt)
return it
}
func (db *DB) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward)
}
func (db *DB) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward)
}
//limit < 0 means no limit
//offset must be >= 0; a negative offset returns nothing
func (db *DB) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward)
}
//limit < 0 means no limit
//offset must be >= 0; a negative offset returns nothing
func (db *DB) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward)
}
func (db *DB) put(wo *WriteOptions, key, value []byte) error {
var errStr *C.char
var k, v *C.char
if len(key) != 0 {
k = (*C.char)(unsafe.Pointer(&key[0]))
}
if len(value) != 0 {
v = (*C.char)(unsafe.Pointer(&value[0]))
}
lenk := len(key)
lenv := len(value)
C.leveldb_put(
db.db, wo.Opt, k, C.size_t(lenk), v, C.size_t(lenv), &errStr)
if errStr != nil {
return saveError(errStr)
}
return nil
}
func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) {
var errStr *C.char
var vallen C.size_t
var k *C.char
if len(key) != 0 {
k = (*C.char)(unsafe.Pointer(&key[0]))
}
value := C.leveldb_get(
db.db, ro.Opt, k, C.size_t(len(key)), &vallen, &errStr)
if errStr != nil {
return nil, saveError(errStr)
}
if value == nil {
return nil, nil
}
defer C.leveldb_free(unsafe.Pointer(value))
return C.GoBytes(unsafe.Pointer(value), C.int(vallen)), nil
}
func (db *DB) delete(wo *WriteOptions, key []byte) error {
var errStr *C.char
var k *C.char
if len(key) != 0 {
k = (*C.char)(unsafe.Pointer(&key[0]))
}
C.leveldb_delete(
db.db, wo.Opt, k, C.size_t(len(key)), &errStr)
if errStr != nil {
return saveError(errStr)
}
return nil
}
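An end-to-end sketch of the wrapper above: open a database from a Config, write a few keys, then walk a closed key range with RangeLimitIterator. The path and sizes are illustrative.

package main

import (
	"fmt"

	"github.com/siddontang/ledisdb/leveldb"
)

func main() {
	cfg := &leveldb.Config{
		Path:            "/tmp/db_example",
		BlockSize:       32768,
		WriteBufferSize: 67108864,
		CacheSize:       4194304,
		MaxOpenFiles:    1024,
	}

	db, err := leveldb.OpenWithConfig(cfg)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	for i := 0; i < 5; i++ {
		key := []byte(fmt.Sprintf("key_%d", i))
		if err := db.Put(key, []byte("value")); err != nil {
			panic(err)
		}
	}

	// iterate keys in [key_1, key_3] with no offset and no limit
	it := db.RangeLimitIterator([]byte("key_1"), []byte("key_3"), leveldb.RangeClose, 0, -1)
	defer it.Close()
	for ; it.Valid(); it.Next() {
		fmt.Printf("%s = %s\n", it.Key(), it.Value())
	}
}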

19
leveldb/filterpolicy.go Normal file
View File

@ -0,0 +1,19 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include <stdlib.h>
// #include "leveldb/c.h"
import "C"
type FilterPolicy struct {
Policy *C.leveldb_filterpolicy_t
}
func NewBloomFilter(bitsPerKey int) *FilterPolicy {
policy := C.leveldb_filterpolicy_create_bloom(C.int(bitsPerKey))
return &FilterPolicy{policy}
}
func (fp *FilterPolicy) Close() {
C.leveldb_filterpolicy_destroy(fp.Policy)
}

235
leveldb/iterator.go Normal file
View File

@ -0,0 +1,235 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include <stdlib.h>
// #include "leveldb/c.h"
import "C"
import (
"bytes"
"unsafe"
)
const (
IteratorForward uint8 = 0
IteratorBackward uint8 = 1
)
const (
RangeClose uint8 = 0x00
RangeLOpen uint8 = 0x01
RangeROpen uint8 = 0x10
RangeOpen uint8 = 0x11
)
//Min must be less than or equal to Max
//range types:
//RangeClose: [min, max]
//RangeOpen: (min, max)
//RangeLOpen: (min, max]
//RangeROpen: [min, max)
type Range struct {
Min []byte
Max []byte
Type uint8
}
type Iterator struct {
it *C.leveldb_iterator_t
}
func (it *Iterator) Key() []byte {
var klen C.size_t
kdata := C.leveldb_iter_key(it.it, &klen)
if kdata == nil {
return nil
}
return C.GoBytes(unsafe.Pointer(kdata), C.int(klen))
}
func (it *Iterator) Value() []byte {
var vlen C.size_t
vdata := C.leveldb_iter_value(it.it, &vlen)
if vdata == nil {
return nil
}
return C.GoBytes(unsafe.Pointer(vdata), C.int(vlen))
}
func (it *Iterator) Close() {
C.leveldb_iter_destroy(it.it)
it.it = nil
}
func (it *Iterator) Valid() bool {
return ucharToBool(C.leveldb_iter_valid(it.it))
}
func (it *Iterator) Next() {
C.leveldb_iter_next(it.it)
}
func (it *Iterator) Prev() {
C.leveldb_iter_prev(it.it)
}
func (it *Iterator) SeekToFirst() {
C.leveldb_iter_seek_to_first(it.it)
}
func (it *Iterator) SeekToLast() {
C.leveldb_iter_seek_to_last(it.it)
}
func (it *Iterator) Seek(key []byte) {
C.leveldb_iter_seek(it.it, (*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key)))
}
func (it *Iterator) Find(key []byte) []byte {
it.Seek(key)
if it.Valid() {
var klen C.size_t
kdata := C.leveldb_iter_key(it.it, &klen)
if kdata == nil {
return nil
} else if bytes.Equal(slice(unsafe.Pointer(kdata), int(C.int(klen))), key) {
return it.Value()
}
}
return nil
}
type RangeLimitIterator struct {
it *Iterator
r *Range
offset int
limit int
step int
//0 for IteratorForward, 1 for IteratorBackward
direction uint8
}
func (it *RangeLimitIterator) Key() []byte {
return it.it.Key()
}
func (it *RangeLimitIterator) Value() []byte {
return it.it.Value()
}
func (it *RangeLimitIterator) Valid() bool {
if it.offset < 0 {
return false
} else if !it.it.Valid() {
return false
} else if it.limit >= 0 && it.step >= it.limit {
return false
}
if it.direction == IteratorForward {
if it.r.Max != nil {
r := bytes.Compare(it.it.Key(), it.r.Max)
if it.r.Type&RangeROpen > 0 {
return !(r >= 0)
} else {
return !(r > 0)
}
}
} else {
if it.r.Min != nil {
r := bytes.Compare(it.it.Key(), it.r.Min)
if it.r.Type&RangeLOpen > 0 {
return !(r <= 0)
} else {
return !(r < 0)
}
}
}
return true
}
func (it *RangeLimitIterator) Next() {
it.step++
if it.direction == IteratorForward {
it.it.Next()
} else {
it.it.Prev()
}
}
func (it *RangeLimitIterator) Close() {
it.it.Close()
}
func newRangeLimitIterator(i *Iterator, r *Range, offset int, limit int, direction uint8) *RangeLimitIterator {
it := new(RangeLimitIterator)
it.it = i
it.r = r
it.offset = offset
it.limit = limit
it.direction = direction
it.step = 0
if offset < 0 {
return it
}
if direction == IteratorForward {
if r.Min == nil {
it.it.SeekToFirst()
} else {
it.it.Seek(r.Min)
if r.Type&RangeLOpen > 0 {
if it.it.Valid() && bytes.Equal(it.it.Key(), r.Min) {
it.it.Next()
}
}
}
} else {
if r.Max == nil {
it.it.SeekToLast()
} else {
it.it.Seek(r.Max)
if !it.it.Valid() {
it.it.SeekToLast()
} else {
if !bytes.Equal(it.it.Key(), r.Max) {
it.it.Prev()
}
}
if r.Type&RangeROpen > 0 {
if it.it.Valid() && bytes.Equal(it.it.Key(), r.Max) {
it.it.Prev()
}
}
}
}
for i := 0; i < offset; i++ {
if it.it.Valid() {
if it.direction == IteratorForward {
it.it.Next()
} else {
it.it.Prev()
}
}
}
return it
}
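The Find helper above seeks to a key and returns the value only on an exact match, which is what the earlier HMget, HDel, and MGet changes rely on: one iterator serves many point lookups instead of repeated Get calls. A small sketch (lookupMany and the path are illustrative):

package main

import (
	"fmt"

	"github.com/siddontang/ledisdb/leveldb"
)

// lookupMany reuses a single iterator for several point lookups.
func lookupMany(db *leveldb.DB, keys [][]byte) [][]byte {
	it := db.NewIterator()
	defer it.Close()

	values := make([][]byte, len(keys))
	for i, key := range keys {
		values[i] = it.Find(key) // nil when the key is absent
	}
	return values
}

func main() {
	db, err := leveldb.OpenWithConfig(&leveldb.Config{Path: "/tmp/find_example"})
	if err != nil {
		panic(err)
	}
	defer db.Close()

	db.Put([]byte("a"), []byte("1"))
	db.Put([]byte("c"), []byte("3"))

	for _, v := range lookupMany(db, [][]byte{[]byte("a"), []byte("b"), []byte("c")}) {
		fmt.Printf("%q\n", v)
	}
}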

259
leveldb/leveldb_test.go Normal file
View File

@ -0,0 +1,259 @@
package leveldb
import (
"bytes"
"fmt"
"os"
"sync"
"testing"
)
var testConfigJson = []byte(`
{
"path" : "./testdb",
"compression":true,
"block_size" : 32768,
"write_buffer_size" : 2097152,
"cache_size" : 20971520
}
`)
var testOnce sync.Once
var testDB *DB
func getTestDB() *DB {
f := func() {
var err error
testDB, err = Open(testConfigJson)
if err != nil {
println(err.Error())
panic(err)
}
}
testOnce.Do(f)
return testDB
}
func TestSimple(t *testing.T) {
db := getTestDB()
key := []byte("key")
value := []byte("hello world")
if err := db.Put(key, value); err != nil {
t.Fatal(err)
}
if v, err := db.Get(key); err != nil {
t.Fatal(err)
} else if !bytes.Equal(v, value) {
t.Fatal("not equal")
}
if err := db.Delete(key); err != nil {
t.Fatal(err)
}
if v, err := db.Get(key); err != nil {
t.Fatal(err)
} else if v != nil {
t.Fatal("must nil")
}
}
func TestBatch(t *testing.T) {
db := getTestDB()
key1 := []byte("key1")
key2 := []byte("key2")
value := []byte("hello world")
db.Put(key1, value)
db.Put(key2, value)
wb := db.NewWriteBatch()
defer wb.Close()
wb.Delete(key2)
wb.Put(key1, []byte("hello world2"))
if err := wb.Commit(); err != nil {
t.Fatal(err)
}
if v, err := db.Get(key2); err != nil {
t.Fatal(err)
} else if v != nil {
t.Fatal("must nil")
}
if v, err := db.Get(key1); err != nil {
t.Fatal(err)
} else if string(v) != "hello world2" {
t.Fatal(string(v))
}
wb.Delete(key1)
wb.Rollback()
if v, err := db.Get(key1); err != nil {
t.Fatal(err)
} else if string(v) != "hello world2" {
t.Fatal(string(v))
}
db.Delete(key1)
}
func checkIterator(it *RangeLimitIterator, cv ...int) error {
v := make([]string, 0, len(cv))
for ; it.Valid(); it.Next() {
k := it.Key()
v = append(v, string(k))
}
it.Close()
if len(v) != len(cv) {
return fmt.Errorf("len error %d != %d", len(v), len(cv))
}
for k, i := range cv {
if fmt.Sprintf("key_%d", i) != v[k] {
return fmt.Errorf("%s, %d", v[k], i)
}
}
return nil
}
func TestIterator(t *testing.T) {
db := getTestDB()
db.Clear()
for i := 0; i < 10; i++ {
key := []byte(fmt.Sprintf("key_%d", i))
value := []byte("")
db.Put(key, value)
}
var it *RangeLimitIterator
k := func(i int) []byte {
return []byte(fmt.Sprintf("key_%d", i))
}
it = db.RangeLimitIterator(k(1), k(5), RangeClose, 0, -1)
if err := checkIterator(it, 1, 2, 3, 4, 5); err != nil {
t.Fatal(err)
}
it = db.RangeLimitIterator(k(1), k(5), RangeClose, 1, 3)
if err := checkIterator(it, 2, 3, 4); err != nil {
t.Fatal(err)
}
it = db.RangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1)
if err := checkIterator(it, 2, 3, 4, 5); err != nil {
t.Fatal(err)
}
it = db.RangeLimitIterator(k(1), k(5), RangeROpen, 0, -1)
if err := checkIterator(it, 1, 2, 3, 4); err != nil {
t.Fatal(err)
}
it = db.RangeLimitIterator(k(1), k(5), RangeOpen, 0, -1)
if err := checkIterator(it, 2, 3, 4); err != nil {
t.Fatal(err)
}
it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 0, -1)
if err := checkIterator(it, 5, 4, 3, 2, 1); err != nil {
t.Fatal(err)
}
it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 1, 3)
if err := checkIterator(it, 4, 3, 2); err != nil {
t.Fatal(err)
}
it = db.RevRangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1)
if err := checkIterator(it, 5, 4, 3, 2); err != nil {
t.Fatal(err)
}
it = db.RevRangeLimitIterator(k(1), k(5), RangeROpen, 0, -1)
if err := checkIterator(it, 4, 3, 2, 1); err != nil {
t.Fatal(err)
}
it = db.RevRangeLimitIterator(k(1), k(5), RangeOpen, 0, -1)
if err := checkIterator(it, 4, 3, 2); err != nil {
t.Fatal(err)
}
}
func TestSnapshot(t *testing.T) {
db := getTestDB()
key := []byte("key")
value := []byte("hello world")
db.Put(key, value)
s := db.NewSnapshot()
defer s.Close()
db.Put(key, []byte("hello world2"))
if v, err := s.Get(key); err != nil {
t.Fatal(err)
} else if string(v) != string(value) {
t.Fatal(string(v))
}
}
func TestDestroy(t *testing.T) {
db := getTestDB()
db.Put([]byte("a"), []byte("1"))
if err := db.Clear(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(db.cfg.Path); err != nil {
t.Fatal("must exist ", err.Error())
}
if v, err := db.Get([]byte("a")); err != nil {
t.Fatal(err)
} else if string(v) == "1" {
t.Fatal(string(v))
}
db.Destroy()
if _, err := os.Stat(db.cfg.Path); !os.IsNotExist(err) {
t.Fatal("must not exist")
}
}
func TestCloseMore(t *testing.T) {
cfg := new(Config)
cfg.Path = "/tmp/testdb1234"
cfg.CacheSize = 4 * 1024 * 1024
os.RemoveAll(cfg.Path)
for i := 0; i < 100; i++ {
db, err := OpenWithConfig(cfg)
if err != nil {
t.Fatal(err)
}
db.Put([]byte("key"), []byte("value"))
db.Close()
}
}

7
leveldb/levigo-license Normal file
View File

@ -0,0 +1,7 @@
Copyright (c) 2012 Jeffrey M Hodges
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

128
leveldb/options.go Normal file
View File

@ -0,0 +1,128 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include "leveldb/c.h"
import "C"
type CompressionOpt int
const (
NoCompression = CompressionOpt(0)
SnappyCompression = CompressionOpt(1)
)
type Options struct {
Opt *C.leveldb_options_t
}
type ReadOptions struct {
Opt *C.leveldb_readoptions_t
}
type WriteOptions struct {
Opt *C.leveldb_writeoptions_t
}
func NewOptions() *Options {
opt := C.leveldb_options_create()
return &Options{opt}
}
func NewReadOptions() *ReadOptions {
opt := C.leveldb_readoptions_create()
return &ReadOptions{opt}
}
func NewWriteOptions() *WriteOptions {
opt := C.leveldb_writeoptions_create()
return &WriteOptions{opt}
}
func (o *Options) Close() {
C.leveldb_options_destroy(o.Opt)
}
func (o *Options) SetComparator(cmp *C.leveldb_comparator_t) {
C.leveldb_options_set_comparator(o.Opt, cmp)
}
func (o *Options) SetErrorIfExists(error_if_exists bool) {
eie := boolToUchar(error_if_exists)
C.leveldb_options_set_error_if_exists(o.Opt, eie)
}
func (o *Options) SetCache(cache *Cache) {
C.leveldb_options_set_cache(o.Opt, cache.Cache)
}
// func (o *Options) SetEnv(env *Env) {
// C.leveldb_options_set_env(o.Opt, env.Env)
// }
func (o *Options) SetInfoLog(log *C.leveldb_logger_t) {
C.leveldb_options_set_info_log(o.Opt, log)
}
func (o *Options) SetWriteBufferSize(s int) {
C.leveldb_options_set_write_buffer_size(o.Opt, C.size_t(s))
}
func (o *Options) SetParanoidChecks(pc bool) {
C.leveldb_options_set_paranoid_checks(o.Opt, boolToUchar(pc))
}
func (o *Options) SetMaxOpenFiles(n int) {
C.leveldb_options_set_max_open_files(o.Opt, C.int(n))
}
func (o *Options) SetBlockSize(s int) {
C.leveldb_options_set_block_size(o.Opt, C.size_t(s))
}
func (o *Options) SetBlockRestartInterval(n int) {
C.leveldb_options_set_block_restart_interval(o.Opt, C.int(n))
}
func (o *Options) SetCompression(t CompressionOpt) {
C.leveldb_options_set_compression(o.Opt, C.int(t))
}
func (o *Options) SetCreateIfMissing(b bool) {
C.leveldb_options_set_create_if_missing(o.Opt, boolToUchar(b))
}
func (o *Options) SetFilterPolicy(fp *FilterPolicy) {
var policy *C.leveldb_filterpolicy_t
if fp != nil {
policy = fp.Policy
}
C.leveldb_options_set_filter_policy(o.Opt, policy)
}
func (ro *ReadOptions) Close() {
C.leveldb_readoptions_destroy(ro.Opt)
}
func (ro *ReadOptions) SetVerifyChecksums(b bool) {
C.leveldb_readoptions_set_verify_checksums(ro.Opt, boolToUchar(b))
}
func (ro *ReadOptions) SetFillCache(b bool) {
C.leveldb_readoptions_set_fill_cache(ro.Opt, boolToUchar(b))
}
func (ro *ReadOptions) SetSnapshot(snap *Snapshot) {
var s *C.leveldb_snapshot_t
if snap != nil {
s = snap.snap
}
C.leveldb_readoptions_set_snapshot(ro.Opt, s)
}
func (wo *WriteOptions) Close() {
C.leveldb_writeoptions_destroy(wo.Opt)
}
func (wo *WriteOptions) SetSync(b bool) {
C.leveldb_writeoptions_set_sync(wo.Opt, boolToUchar(b))
}

54
leveldb/snapshot.go Normal file
View File

@ -0,0 +1,54 @@
package leveldb
// #cgo LDFLAGS: -lleveldb
// #include <stdint.h>
// #include "leveldb/c.h"
import "C"
type Snapshot struct {
db *DB
snap *C.leveldb_snapshot_t
readOpts *ReadOptions
iteratorOpts *ReadOptions
}
func (s *Snapshot) Close() {
C.leveldb_release_snapshot(s.db.db, s.snap)
s.iteratorOpts.Close()
s.readOpts.Close()
}
func (s *Snapshot) Get(key []byte) ([]byte, error) {
return s.db.get(s.readOpts, key)
}
func (s *Snapshot) NewIterator() *Iterator {
it := new(Iterator)
it.it = C.leveldb_create_iterator(s.db.db, s.iteratorOpts.Opt)
return it
}
func (s *Snapshot) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward)
}
func (s *Snapshot) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward)
}
//limit < 0 means no limit
//offset must be >= 0; a negative offset returns nothing
func (s *Snapshot) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward)
}
//limit < 0 means no limit
//offset must be >= 0; a negative offset returns nothing
func (s *Snapshot) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward)
}
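A brief sketch of how the Snapshot above pins a consistent view while the database keeps changing, and how an iterator opened on it walks the data as of that point, much like the Dump change earlier does with sp.NewIterator(). Path and values are illustrative.

package main

import (
	"fmt"

	"github.com/siddontang/ledisdb/leveldb"
)

func main() {
	db, err := leveldb.OpenWithConfig(&leveldb.Config{Path: "/tmp/snap_example"})
	if err != nil {
		panic(err)
	}
	defer db.Close()

	db.Put([]byte("k"), []byte("v1"))

	// the snapshot pins the current version of the data
	s := db.NewSnapshot()
	defer s.Close()

	db.Put([]byte("k"), []byte("v2"))

	old, _ := s.Get([]byte("k"))  // still "v1"
	cur, _ := db.Get([]byte("k")) // "v2"
	fmt.Printf("snapshot=%s current=%s\n", old, cur)

	// walk every key as of the snapshot
	it := s.NewIterator()
	defer it.Close()
	for it.SeekToFirst(); it.Valid(); it.Next() {
		fmt.Printf("%s => %s\n", it.Key(), it.Value())
	}
}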

43
leveldb/util.go Normal file
View File

@ -0,0 +1,43 @@
package leveldb
// #include "leveldb/c.h"
import "C"
import (
"fmt"
"reflect"
"unsafe"
)
func boolToUchar(b bool) C.uchar {
uc := C.uchar(0)
if b {
uc = C.uchar(1)
}
return uc
}
func ucharToBool(uc C.uchar) bool {
if uc == C.uchar(0) {
return false
}
return true
}
func saveError(errStr *C.char) error {
if errStr != nil {
gs := C.GoString(errStr)
C.leveldb_free(unsafe.Pointer(errStr))
return fmt.Errorf(gs)
}
return nil
}
func slice(p unsafe.Pointer, n int) []byte {
var b []byte
pbyte := (*reflect.SliceHeader)(unsafe.Pointer(&b))
pbyte.Data = uintptr(p)
pbyte.Len = n
pbyte.Cap = n
return b
}

193
log/filehandler.go Normal file
View File

@ -0,0 +1,193 @@
package log
import (
"fmt"
"os"
"path"
"time"
)
type FileHandler struct {
fd *os.File
}
func NewFileHandler(fileName string, flag int) (*FileHandler, error) {
dir := path.Dir(fileName)
os.Mkdir(dir, 0777)
f, err := os.OpenFile(fileName, flag, 0)
if err != nil {
return nil, err
}
h := new(FileHandler)
h.fd = f
return h, nil
}
func (h *FileHandler) Write(b []byte) (n int, err error) {
return h.fd.Write(b)
}
func (h *FileHandler) Close() error {
return h.fd.Close()
}
type RotatingFileHandler struct {
fd *os.File
fileName string
maxBytes int
backupCount int
}
func NewRotatingFileHandler(fileName string, maxBytes int, backupCount int) (*RotatingFileHandler, error) {
dir := path.Dir(fileName)
os.Mkdir(dir, 0777)
h := new(RotatingFileHandler)
if maxBytes <= 0 {
return nil, fmt.Errorf("invalid max bytes")
}
h.fileName = fileName
h.maxBytes = maxBytes
h.backupCount = backupCount
var err error
h.fd, err = os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, err
}
return h, nil
}
func (h *RotatingFileHandler) Write(p []byte) (n int, err error) {
h.doRollover()
return h.fd.Write(p)
}
func (h *RotatingFileHandler) Close() error {
if h.fd != nil {
return h.fd.Close()
}
return nil
}
func (h *RotatingFileHandler) doRollover() {
f, err := h.fd.Stat()
if err != nil {
return
}
if h.maxBytes <= 0 {
return
} else if f.Size() < int64(h.maxBytes) {
return
}
if h.backupCount > 0 {
h.fd.Close()
for i := h.backupCount - 1; i > 0; i-- {
sfn := fmt.Sprintf("%s.%d", h.fileName, i)
dfn := fmt.Sprintf("%s.%d", h.fileName, i+1)
os.Rename(sfn, dfn)
}
dfn := fmt.Sprintf("%s.1", h.fileName)
os.Rename(h.fileName, dfn)
h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
}
}
//refer: http://docs.python.org/2/library/logging.handlers.html
//works like Python's TimedRotatingFileHandler
type TimeRotatingFileHandler struct {
fd *os.File
baseName string
interval int64
suffix string
rolloverAt int64
}
const (
WhenSecond = iota
WhenMinute
WhenHour
WhenDay
)
func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) {
dir := path.Dir(baseName)
os.Mkdir(dir, 0777)
h := new(TimeRotatingFileHandler)
h.baseName = baseName
switch when {
case WhenSecond:
h.interval = 1
h.suffix = "2006-01-02_15-04-05"
case WhenMinute:
h.interval = 60
h.suffix = "2006-01-02_15-04"
case WhenHour:
h.interval = 3600
h.suffix = "2006-01-02_15"
case WhenDay:
h.interval = 3600 * 24
h.suffix = "2006-01-02"
default:
return nil, fmt.Errorf("invalid when_rotate: %d", when)
}
h.interval = h.interval * int64(interval)
var err error
h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, err
}
fInfo, _ := h.fd.Stat()
h.rolloverAt = fInfo.ModTime().Unix() + h.interval
return h, nil
}
func (h *TimeRotatingFileHandler) doRollover() {
//refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py
now := time.Now()
if h.rolloverAt <= now.Unix() {
fName := h.baseName + now.Format(h.suffix)
h.fd.Close()
e := os.Rename(h.baseName, fName)
if e != nil {
panic(e)
}
h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
h.rolloverAt = time.Now().Unix() + h.interval
}
}
func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) {
h.doRollover()
return h.fd.Write(b)
}
func (h *TimeRotatingFileHandler) Close() error {
return h.fd.Close()
}
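One possible way to wire the size-based handler above into the Logger from log/log.go below; the path, size limit, and backup count are illustrative.

package main

import (
	"time"

	"github.com/siddontang/ledisdb/log"
)

func main() {
	// rotate ./logs/ledis.log once it grows past ~1MB, keeping 3 backups
	h, err := log.NewRotatingFileHandler("./logs/ledis.log", 1024*1024, 3)
	if err != nil {
		panic(err)
	}

	l := log.NewDefault(h)
	l.Info("server started on %s", "127.0.0.1:6380")
	l.Error("example error: %v", "connection reset")

	// messages are written by a background goroutine; give it a moment before closing
	time.Sleep(100 * time.Millisecond)
	l.Close()
}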

45
log/handler.go Normal file
View File

@ -0,0 +1,45 @@
package log
import (
"io"
)
type Handler interface {
Write(p []byte) (n int, err error)
Close() error
}
type StreamHandler struct {
w io.Writer
}
func NewStreamHandler(w io.Writer) (*StreamHandler, error) {
h := new(StreamHandler)
h.w = w
return h, nil
}
func (h *StreamHandler) Write(b []byte) (n int, err error) {
return h.w.Write(b)
}
func (h *StreamHandler) Close() error {
return nil
}
type NullHandler struct {
}
func NewNullHandler() (*NullHandler, error) {
return new(NullHandler), nil
}
func (h *NullHandler) Write(b []byte) (n int, err error) {
return len(b), nil
}
func (h *NullHandler) Close() {
}

226
log/log.go Normal file
View File

@ -0,0 +1,226 @@
package log
import (
"fmt"
"os"
"runtime"
"strconv"
"sync"
"time"
)
const (
LevelTrace = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelFatal
)
const (
Ltime = 1 << iota //time format "2006/01/02 15:04:05"
Lfile //file.go:123
Llevel //[Trace|Debug|Info...]
)
var LevelName [6]string = [6]string{"Trace", "Debug", "Info", "Warn", "Error", "Fatal"}
const TimeFormat = "2006/01/02 15:04:05"
const maxBufPoolSize = 16
type Logger struct {
sync.Mutex
level int
flag int
handler Handler
quit chan struct{}
msg chan []byte
bufs [][]byte
}
func New(handler Handler, flag int) *Logger {
var l = new(Logger)
l.level = LevelInfo
l.handler = handler
l.flag = flag
l.quit = make(chan struct{})
l.msg = make(chan []byte, 1024)
l.bufs = make([][]byte, 0, 16)
go l.run()
return l
}
func NewDefault(handler Handler) *Logger {
return New(handler, Ltime|Lfile|Llevel)
}
func newStdHandler() *StreamHandler {
h, _ := NewStreamHandler(os.Stdout)
return h
}
var std = NewDefault(newStdHandler())
func (l *Logger) run() {
for {
select {
case msg := <-l.msg:
l.handler.Write(msg)
l.putBuf(msg)
case <-l.quit:
l.handler.Close()
return
}
}
}
func (l *Logger) popBuf() []byte {
l.Lock()
var buf []byte
if len(l.bufs) == 0 {
buf = make([]byte, 0, 1024)
} else {
buf = l.bufs[len(l.bufs)-1]
l.bufs = l.bufs[0 : len(l.bufs)-1]
}
l.Unlock()
return buf
}
func (l *Logger) putBuf(buf []byte) {
l.Lock()
if len(l.bufs) < maxBufPoolSize {
buf = buf[0:0]
l.bufs = append(l.bufs, buf)
}
l.Unlock()
}
func (l *Logger) Close() {
if l.quit == nil {
return
}
close(l.quit)
l.quit = nil
}
func (l *Logger) SetLevel(level int) {
l.level = level
}
func (l *Logger) Output(callDepth int, level int, format string, v ...interface{}) {
if l.level > level {
return
}
buf := l.popBuf()
if l.flag&Ltime > 0 {
now := time.Now().Format(TimeFormat)
buf = append(buf, '[')
buf = append(buf, now...)
buf = append(buf, "] "...)
}
if l.flag&Lfile > 0 {
_, file, line, ok := runtime.Caller(callDepth)
if !ok {
file = "???"
line = 0
} else {
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
file = file[i+1:]
break
}
}
}
buf = append(buf, file...)
buf = append(buf, ':')
buf = strconv.AppendInt(buf, int64(line), 10)
}
if l.flag&Llevel > 0 {
buf = append(buf, '[')
buf = append(buf, LevelName[level]...)
buf = append(buf, "] "...)
}
s := fmt.Sprintf(format, v...)
buf = append(buf, s...)
if len(s) == 0 || s[len(s)-1] != '\n' {
buf = append(buf, '\n')
}
l.msg <- buf
}
func (l *Logger) Trace(format string, v ...interface{}) {
l.Output(2, LevelTrace, format, v...)
}
func (l *Logger) Debug(format string, v ...interface{}) {
l.Output(2, LevelDebug, format, v...)
}
func (l *Logger) Info(format string, v ...interface{}) {
l.Output(2, LevelInfo, format, v...)
}
func (l *Logger) Warn(format string, v ...interface{}) {
l.Output(2, LevelWarn, format, v...)
}
func (l *Logger) Error(format string, v ...interface{}) {
l.Output(2, LevelError, format, v...)
}
func (l *Logger) Fatal(format string, v ...interface{}) {
l.Output(2, LevelFatal, format, v...)
}
func SetLevel(level int) {
std.SetLevel(level)
}
func Trace(format string, v ...interface{}) {
std.Output(2, LevelTrace, format, v...)
}
func Debug(format string, v ...interface{}) {
std.Output(2, LevelDebug, format, v...)
}
func Info(format string, v ...interface{}) {
std.Output(2, LevelInfo, format, v...)
}
func Warn(format string, v ...interface{}) {
std.Output(2, LevelWarn, format, v...)
}
func Error(format string, v ...interface{}) {
std.Output(2, LevelError, format, v...)
}
func Fatal(format string, v ...interface{}) {
std.Output(2, LevelFatal, format, v...)
}
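A quick sketch of the package-level logger above: pick a level, then log with printf-style formatting. The short sleep only gives the logger's background goroutine time to flush, since messages are delivered over a channel.

package main

import (
	"time"

	"github.com/siddontang/ledisdb/log"
)

func main() {
	// the default logger writes to stdout with time, file, and level prefixes
	log.SetLevel(log.LevelDebug)

	log.Debug("loading config from %s", "./etc/ledis.json")
	log.Info("listening on %s", "127.0.0.1:6380")
	log.Warn("slow request: %dms", 250)

	time.Sleep(100 * time.Millisecond)
}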

52
log/log_test.go Normal file
View File

@ -0,0 +1,52 @@
package log
import (
"os"
"testing"
)
func TestStdStreamLog(t *testing.T) {
h, _ := NewStreamHandler(os.Stdout)
s := NewDefault(h)
s.Info("hello world")
s.Close()
Info("hello world")
}
func TestRotatingFileLog(t *testing.T) {
path := "./test_log"
os.RemoveAll(path)
os.Mkdir(path, 0777)
fileName := path + "/test"
h, err := NewRotatingFileHandler(fileName, 10, 2)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 10)
h.Write(buf)
h.Write(buf)
if _, err := os.Stat(fileName + ".1"); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(fileName + ".2"); err == nil {
t.Fatal(err)
}
h.Write(buf)
if _, err := os.Stat(fileName + ".2"); err != nil {
t.Fatal(err)
}
h.Close()
os.RemoveAll(path)
}

62
log/sockethandler.go Normal file
View File

@ -0,0 +1,62 @@
package log
import (
"encoding/binary"
"net"
"time"
)
type SocketHandler struct {
c net.Conn
protocol string
addr string
}
func NewSocketHandler(protocol string, addr string) (*SocketHandler, error) {
s := new(SocketHandler)
s.protocol = protocol
s.addr = addr
return s, nil
}
func (h *SocketHandler) Write(p []byte) (n int, err error) {
if err = h.connect(); err != nil {
return
}
buf := make([]byte, len(p)+4)
binary.BigEndian.PutUint32(buf, uint32(len(p)))
copy(buf[4:], p)
n, err = h.c.Write(buf)
if err != nil {
h.c.Close()
h.c = nil
}
return
}
func (h *SocketHandler) Close() error {
if h.c != nil {
h.c.Close()
h.c = nil
}
return nil
}
func (h *SocketHandler) connect() error {
if h.c != nil {
return nil
}
var err error
h.c, err = net.DialTimeout(h.protocol, h.addr, 20*time.Second)
if err != nil {
return err
}
return nil
}
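SocketHandler frames each record as a 4-byte big-endian length followed by the payload. A minimal sketch of a matching receiver, assuming the reader just prints each frame; the listen address is illustrative and not part of the commit.

package main

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:12345") // illustrative address
	if err != nil {
		panic(err)
	}
	for {
		conn, err := ln.Accept()
		if err != nil {
			continue
		}
		go func(c net.Conn) {
			defer c.Close()
			head := make([]byte, 4)
			for {
				// Length prefix written by SocketHandler.Write.
				if _, err := io.ReadFull(c, head); err != nil {
					return
				}
				payload := make([]byte, binary.BigEndian.Uint32(head))
				if _, err := io.ReadFull(c, payload); err != nil {
					return
				}
				fmt.Print(string(payload)) // log lines already end in '\n'
			}
		}(conn)
	}
}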

View File

@ -1,7 +1,7 @@
package server
import (
- "github.com/siddontang/go-log/log"
+ "github.com/siddontang/ledisdb/log"
)
const (

View File

@ -1,7 +1,7 @@
package server
import (
- "github.com/garyburd/redigo/redis"
+ "github.com/siddontang/ledisdb/client/go/redis"
"os"
"sync"
"testing"

View File

@ -4,8 +4,8 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"errors" "errors"
"github.com/siddontang/go-log/log"
"github.com/siddontang/ledisdb/ledis" "github.com/siddontang/ledisdb/ledis"
"github.com/siddontang/ledisdb/log"
"io" "io"
"net" "net"
"runtime" "runtime"

View File

@ -59,7 +59,7 @@ func hdelCommand(c *client) error {
return ErrCmdParams return ErrCmdParams
} }
if n, err := c.db.HDel(args[0], args[1:]); err != nil { if n, err := c.db.HDel(args[0], args[1:]...); err != nil {
return err return err
} else { } else {
c.writeInteger(n) c.writeInteger(n)
@ -138,7 +138,7 @@ func hmgetCommand(c *client) error {
return ErrCmdParams return ErrCmdParams
} }
if v, err := c.db.HMget(args[0], args[1:]); err != nil { if v, err := c.db.HMget(args[0], args[1:]...); err != nil {
return err return err
} else { } else {
c.writeArray(v) c.writeArray(v)
@ -207,6 +207,61 @@ func hclearCommand(c *client) error {
return nil return nil
} }
func hexpireCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
duration, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.HExpire(args[0], duration); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func hexpireAtCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
when, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.HExpireAt(args[0], when); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func httlCommand(c *client) error {
args := c.args
if len(args) != 1 {
return ErrCmdParams
}
if v, err := c.db.HTTL(args[0]); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func init() {
register("hdel", hdelCommand)
register("hexists", hexistsCommand)
@ -223,4 +278,7 @@ func init() {
//ledisdb special command
register("hclear", hclearCommand)
register("hexpire", hexpireCommand)
register("hexpireat", hexpireAtCommand)
register("httl", httlCommand)
}
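With these registrations the hash type gains ledisdb-specific TTL commands. A quick client-side sketch using the bundled Go client; it is not part of the commit and assumes a ledisdb server reachable at the (illustrative) address shown.

package main

import (
	"fmt"

	"github.com/siddontang/ledisdb/client/go/redis"
)

func main() {
	c, err := redis.Dial("tcp", "127.0.0.1:6380") // illustrative address
	if err != nil {
		panic(err)
	}
	defer c.Close()

	c.Do("hset", "h_key", "f", "v")

	// hexpire takes a duration in seconds; hexpireat takes an absolute unix timestamp.
	if n, err := redis.Int(c.Do("hexpire", "h_key", 60)); err == nil {
		fmt.Println("hexpire:", n)
	}
	if ttl, err := redis.Int64(c.Do("httl", "h_key")); err == nil {
		fmt.Println("httl:", ttl)
	}
}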

View File

@ -2,7 +2,7 @@ package server
import (
"fmt"
- "github.com/garyburd/redigo/redis"
+ "github.com/siddontang/ledisdb/client/go/redis"
"strconv"
"testing"
)

View File

@ -203,6 +203,65 @@ func mgetCommand(c *client) error {
return nil
}
func expireCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
duration, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.Expire(args[0], duration); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func expireAtCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
when, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.ExpireAt(args[0], when); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func ttlCommand(c *client) error {
args := c.args
if len(args) != 1 {
return ErrCmdParams
}
if v, err := c.db.TTL(args[0]); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
// func (db *DB) Expire(key []byte, duration int64) (int64, error)
// func (db *DB) ExpireAt(key []byte, when int64) (int64, error)
// func (db *DB) TTL(key []byte) (int64, error)
func init() {
register("decr", decrCommand)
register("decrby", decrbyCommand)
@ -216,4 +275,7 @@ func init() {
register("mset", msetCommand)
register("set", setCommand)
register("setnx", setnxCommand)
register("expire", expireCommand)
register("expireat", expireAtCommand)
register("ttl", ttlCommand)
}

View File

@ -1,7 +1,7 @@
package server
import (
- "github.com/garyburd/redigo/redis"
+ "github.com/siddontang/ledisdb/client/go/redis"
"testing"
)

View File

@ -143,6 +143,61 @@ func lclearCommand(c *client) error {
return nil
}
func lexpireCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
duration, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.LExpire(args[0], duration); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func lexpireAtCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
when, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.LExpireAt(args[0], when); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func lttlCommand(c *client) error {
args := c.args
if len(args) != 1 {
return ErrCmdParams
}
if v, err := c.db.LTTL(args[0]); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func init() {
register("lindex", lindexCommand)
register("llen", llenCommand)
@ -155,5 +210,7 @@ func init() {
//ledisdb special command
register("lclear", lclearCommand)
register("lexpire", lexpireCommand)
register("lexpireat", lexpireAtCommand)
register("lttl", lttlCommand)
}

View File

@ -2,7 +2,7 @@ package server
import (
"fmt"
- "github.com/garyburd/redigo/redis"
+ "github.com/siddontang/ledisdb/client/go/redis"
"strconv"
"testing"
)

View File

@ -3,14 +3,14 @@ package server
import (
"bytes"
"fmt"
- "github.com/siddontang/go-leveldb/leveldb"
+ "github.com/siddontang/ledisdb/leveldb"
"os"
"testing"
"time"
)
func checkDataEqual(master *App, slave *App) error {
- it := master.ldb.DataDB().Iterator(nil, nil, leveldb.RangeClose, 0, -1)
+ it := master.ldb.DataDB().RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
for ; it.Valid(); it.Next() {
key := it.Key()
value := it.Value()

63
server/cmd_ttl_test.go Normal file
View File

@ -0,0 +1,63 @@
package server
import (
"github.com/siddontang/ledisdb/client/go/redis"
"testing"
"time"
)
func now() int64 {
return time.Now().Unix()
}
func TestKVExpire(t *testing.T) {
c := getTestConn()
defer c.Close()
k := "a_ttl"
c.Do("set", k, "123")
// expire + ttl
exp := int64(10)
if n, err := redis.Int(c.Do("expire", k, exp)); err != nil {
t.Fatal(err)
} else if n != 1 {
t.Fatal(n)
}
if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil {
t.Fatal(err)
} else if ttl != exp {
t.Fatal(ttl)
}
// expireat + ttl
tm := now() + 3
if n, err := redis.Int(c.Do("expireat", k, tm)); err != nil {
t.Fatal(err)
} else if n != 1 {
t.Fatal(n)
}
if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil {
t.Fatal(err)
} else if ttl != 3 {
t.Fatal(ttl)
}
kErr := "not_exist_ttl"
// err - expire, expireat
if n, err := redis.Int(c.Do("expire", kErr, tm)); err != nil || n != 0 {
t.Fatal(false)
}
if n, err := redis.Int(c.Do("expireat", kErr, tm)); err != nil || n != 0 {
t.Fatal(false)
}
if n, err := redis.Int(c.Do("ttl", kErr)); err != nil || n != -1 {
t.Fatal(false)
}
}

View File

@ -421,6 +421,61 @@ func zclearCommand(c *client) error {
return nil
}
func zexpireCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
duration, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.ZExpire(args[0], duration); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func zexpireAtCommand(c *client) error {
args := c.args
if len(args) != 2 {
return ErrCmdParams
}
when, err := ledis.StrInt64(args[1], nil)
if err != nil {
return err
}
if v, err := c.db.ZExpireAt(args[0], when); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func zttlCommand(c *client) error {
args := c.args
if len(args) != 1 {
return ErrCmdParams
}
if v, err := c.db.ZTTL(args[0]); err != nil {
return err
} else {
c.writeInteger(v)
}
return nil
}
func init() {
register("zadd", zaddCommand)
register("zcard", zcardCommand)
@ -438,6 +493,9 @@ func init() {
register("zscore", zscoreCommand)
//ledisdb special command
register("zclear", zclearCommand)
register("zexpire", zexpireCommand)
register("zexpireat", zexpireAtCommand)
register("zttl", zttlCommand)
}

View File

@ -2,7 +2,7 @@ package server
import (
"fmt"
- "github.com/garyburd/redigo/redis"
+ "github.com/siddontang/ledisdb/client/go/redis"
"strconv"
"testing"
)

View File

@ -7,8 +7,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/siddontang/go-log/log"
"github.com/siddontang/ledisdb/ledis" "github.com/siddontang/ledisdb/ledis"
"github.com/siddontang/ledisdb/log"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"