mirror of https://github.com/ledisdb/ledisdb.git
Merge branch 'develop'
This commit is contained in:
commit
8e8f3704e6
|
@ -2,6 +2,4 @@
|
|||
|
||||
. ./dev.sh
|
||||
|
||||
go get -u github.com/siddontang/go-leveldb/leveldb
|
||||
go get -u github.com/siddontang/go-log/log
|
||||
go get -u github.com/garyburd/redigo/redis
|
||||
#nothing to do now
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
watchState = 1 << iota
|
||||
multiState
|
||||
subscribeState
|
||||
monitorState
|
||||
)
|
||||
|
||||
type commandInfo struct {
|
||||
set, clear int
|
||||
}
|
||||
|
||||
var commandInfos = map[string]commandInfo{
|
||||
"WATCH": commandInfo{set: watchState},
|
||||
"UNWATCH": commandInfo{clear: watchState},
|
||||
"MULTI": commandInfo{set: multiState},
|
||||
"EXEC": commandInfo{clear: watchState | multiState},
|
||||
"DISCARD": commandInfo{clear: watchState | multiState},
|
||||
"PSUBSCRIBE": commandInfo{set: subscribeState},
|
||||
"SUBSCRIBE": commandInfo{set: subscribeState},
|
||||
"MONITOR": commandInfo{set: monitorState},
|
||||
}
|
||||
|
||||
func lookupCommandInfo(commandName string) commandInfo {
|
||||
return commandInfos[strings.ToUpper(commandName)]
|
||||
}
|
|
@ -0,0 +1,418 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// conn is the low-level implementation of Conn
|
||||
type conn struct {
|
||||
|
||||
// Shared
|
||||
mu sync.Mutex
|
||||
pending int
|
||||
err error
|
||||
conn net.Conn
|
||||
|
||||
// Read
|
||||
readTimeout time.Duration
|
||||
br *bufio.Reader
|
||||
|
||||
// Write
|
||||
writeTimeout time.Duration
|
||||
bw *bufio.Writer
|
||||
|
||||
// Scratch space for formatting argument length.
|
||||
// '*' or '$', length, "\r\n"
|
||||
lenScratch [32]byte
|
||||
|
||||
// Scratch space for formatting integers and floats.
|
||||
numScratch [40]byte
|
||||
}
|
||||
|
||||
// Dial connects to the Redis server at the given network and address.
|
||||
func Dial(network, address string) (Conn, error) {
|
||||
c, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(c, 0, 0), nil
|
||||
}
|
||||
|
||||
// DialTimeout acts like Dial but takes timeouts for establishing the
|
||||
// connection to the server, writing a command and reading a reply.
|
||||
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
|
||||
var c net.Conn
|
||||
var err error
|
||||
if connectTimeout > 0 {
|
||||
c, err = net.DialTimeout(network, address, connectTimeout)
|
||||
} else {
|
||||
c, err = net.Dial(network, address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(c, readTimeout, writeTimeout), nil
|
||||
}
|
||||
|
||||
// NewConn returns a new Redigo connection for the given net connection.
|
||||
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
|
||||
return &conn{
|
||||
conn: netConn,
|
||||
bw: bufio.NewWriter(netConn),
|
||||
br: bufio.NewReader(netConn),
|
||||
readTimeout: readTimeout,
|
||||
writeTimeout: writeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.mu.Lock()
|
||||
err := c.err
|
||||
if c.err == nil {
|
||||
c.err = errors.New("redigo: closed")
|
||||
err = c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) fatal(err error) error {
|
||||
c.mu.Lock()
|
||||
if c.err == nil {
|
||||
c.err = err
|
||||
// Close connection to force errors on subsequent calls and to unblock
|
||||
// other reader or writer.
|
||||
c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Err() error {
|
||||
c.mu.Lock()
|
||||
err := c.err
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeLen(prefix byte, n int) error {
|
||||
c.lenScratch[len(c.lenScratch)-1] = '\n'
|
||||
c.lenScratch[len(c.lenScratch)-2] = '\r'
|
||||
i := len(c.lenScratch) - 3
|
||||
for {
|
||||
c.lenScratch[i] = byte('0' + n%10)
|
||||
i -= 1
|
||||
n = n / 10
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.lenScratch[i] = prefix
|
||||
_, err := c.bw.Write(c.lenScratch[i:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeString(s string) error {
|
||||
c.writeLen('$', len(s))
|
||||
c.bw.WriteString(s)
|
||||
_, err := c.bw.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeBytes(p []byte) error {
|
||||
c.writeLen('$', len(p))
|
||||
c.bw.Write(p)
|
||||
_, err := c.bw.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeInt64(n int64) error {
|
||||
return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
|
||||
}
|
||||
|
||||
func (c *conn) writeFloat64(n float64) error {
|
||||
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
|
||||
}
|
||||
|
||||
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
|
||||
c.writeLen('*', 1+len(args))
|
||||
err = c.writeString(cmd)
|
||||
for _, arg := range args {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
err = c.writeString(arg)
|
||||
case []byte:
|
||||
err = c.writeBytes(arg)
|
||||
case int:
|
||||
err = c.writeInt64(int64(arg))
|
||||
case int64:
|
||||
err = c.writeInt64(arg)
|
||||
case float64:
|
||||
err = c.writeFloat64(arg)
|
||||
case bool:
|
||||
if arg {
|
||||
err = c.writeString("1")
|
||||
} else {
|
||||
err = c.writeString("0")
|
||||
}
|
||||
case nil:
|
||||
err = c.writeString("")
|
||||
default:
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprint(&buf, arg)
|
||||
err = c.writeBytes(buf.Bytes())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) readLine() ([]byte, error) {
|
||||
p, err := c.br.ReadSlice('\n')
|
||||
if err == bufio.ErrBufferFull {
|
||||
return nil, errors.New("redigo: long response line")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i := len(p) - 2
|
||||
if i < 0 || p[i] != '\r' {
|
||||
return nil, errors.New("redigo: bad response line terminator")
|
||||
}
|
||||
return p[:i], nil
|
||||
}
|
||||
|
||||
// parseLen parses bulk string and array lengths.
|
||||
func parseLen(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return -1, errors.New("redigo: malformed length")
|
||||
}
|
||||
|
||||
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
|
||||
// handle $-1 and $-1 null replies.
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
var n int
|
||||
for _, b := range p {
|
||||
n *= 10
|
||||
if b < '0' || b > '9' {
|
||||
return -1, errors.New("redigo: illegal bytes in length")
|
||||
}
|
||||
n += int(b - '0')
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// parseInt parses an integer reply.
|
||||
func parseInt(p []byte) (interface{}, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, errors.New("redigo: malformed integer")
|
||||
}
|
||||
|
||||
var negate bool
|
||||
if p[0] == '-' {
|
||||
negate = true
|
||||
p = p[1:]
|
||||
if len(p) == 0 {
|
||||
return 0, errors.New("redigo: malformed integer")
|
||||
}
|
||||
}
|
||||
|
||||
var n int64
|
||||
for _, b := range p {
|
||||
n *= 10
|
||||
if b < '0' || b > '9' {
|
||||
return 0, errors.New("redigo: illegal bytes in length")
|
||||
}
|
||||
n += int64(b - '0')
|
||||
}
|
||||
|
||||
if negate {
|
||||
n = -n
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var (
|
||||
okReply interface{} = "OK"
|
||||
pongReply interface{} = "PONG"
|
||||
)
|
||||
|
||||
func (c *conn) readReply() (interface{}, error) {
|
||||
line, err := c.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(line) == 0 {
|
||||
return nil, errors.New("redigo: short response line")
|
||||
}
|
||||
switch line[0] {
|
||||
case '+':
|
||||
switch {
|
||||
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
|
||||
// Avoid allocation for frequent "+OK" response.
|
||||
return okReply, nil
|
||||
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
|
||||
// Avoid allocation in PING command benchmarks :)
|
||||
return pongReply, nil
|
||||
default:
|
||||
return string(line[1:]), nil
|
||||
}
|
||||
case '-':
|
||||
return Error(string(line[1:])), nil
|
||||
case ':':
|
||||
return parseInt(line[1:])
|
||||
case '$':
|
||||
n, err := parseLen(line[1:])
|
||||
if n < 0 || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := make([]byte, n)
|
||||
_, err = io.ReadFull(c.br, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if line, err := c.readLine(); err != nil {
|
||||
return nil, err
|
||||
} else if len(line) != 0 {
|
||||
return nil, errors.New("redigo: bad bulk string format")
|
||||
}
|
||||
return p, nil
|
||||
case '*':
|
||||
n, err := parseLen(line[1:])
|
||||
if n < 0 || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := make([]interface{}, n)
|
||||
for i := range r {
|
||||
r[i], err = c.readReply()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
return nil, errors.New("redigo: unexpected response line")
|
||||
}
|
||||
|
||||
func (c *conn) Send(cmd string, args ...interface{}) error {
|
||||
c.mu.Lock()
|
||||
c.pending += 1
|
||||
c.mu.Unlock()
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
if err := c.writeCommand(cmd, args); err != nil {
|
||||
return c.fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Flush() error {
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
if err := c.bw.Flush(); err != nil {
|
||||
return c.fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Receive() (reply interface{}, err error) {
|
||||
c.mu.Lock()
|
||||
// There can be more receives than sends when using pub/sub. To allow
|
||||
// normal use of the connection after unsubscribe from all channels, do not
|
||||
// decrement pending to a negative value.
|
||||
if c.pending > 0 {
|
||||
c.pending -= 1
|
||||
}
|
||||
c.mu.Unlock()
|
||||
if c.readTimeout != 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
if reply, err = c.readReply(); err != nil {
|
||||
return nil, c.fatal(err)
|
||||
}
|
||||
if err, ok := reply.(Error); ok {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
pending := c.pending
|
||||
c.pending = 0
|
||||
c.mu.Unlock()
|
||||
|
||||
if cmd == "" && pending == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
if cmd != "" {
|
||||
c.writeCommand(cmd, args)
|
||||
}
|
||||
|
||||
if err := c.bw.Flush(); err != nil {
|
||||
return nil, c.fatal(err)
|
||||
}
|
||||
|
||||
if c.readTimeout != 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
if cmd == "" {
|
||||
reply := make([]interface{}, pending)
|
||||
for i := range reply {
|
||||
r, e := c.readReply()
|
||||
if e != nil {
|
||||
return nil, c.fatal(e)
|
||||
}
|
||||
reply[i] = r
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
var reply interface{}
|
||||
for i := 0; i <= pending; i++ {
|
||||
var e error
|
||||
if reply, e = c.readReply(); e != nil {
|
||||
return nil, c.fatal(e)
|
||||
}
|
||||
if e, ok := reply.(Error); ok && err == nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
return reply, err
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package redis is a client for the Redis database.
|
||||
//
|
||||
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
|
||||
// documentation about this package.
|
||||
//
|
||||
// Connections
|
||||
//
|
||||
// The Conn interface is the primary interface for working with Redis.
|
||||
// Applications create connections by calling the Dial, DialWithTimeout or
|
||||
// NewConn functions. In the future, functions will be added for creating
|
||||
// sharded and other types of connections.
|
||||
//
|
||||
// The application must call the connection Close method when the application
|
||||
// is done with the connection.
|
||||
//
|
||||
// Executing Commands
|
||||
//
|
||||
// The Conn interface has a generic method for executing Redis commands:
|
||||
//
|
||||
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||
//
|
||||
// The Redis command reference (http://redis.io/commands) lists the available
|
||||
// commands. An example of using the Redis APPEND command is:
|
||||
//
|
||||
// n, err := conn.Do("APPEND", "key", "value")
|
||||
//
|
||||
// The Do method converts command arguments to binary strings for transmission
|
||||
// to the server as follows:
|
||||
//
|
||||
// Go Type Conversion
|
||||
// []byte Sent as is
|
||||
// string Sent as is
|
||||
// int, int64 strconv.FormatInt(v)
|
||||
// float64 strconv.FormatFloat(v, 'g', -1, 64)
|
||||
// bool true -> "1", false -> "0"
|
||||
// nil ""
|
||||
// all other types fmt.Print(v)
|
||||
//
|
||||
// Redis command reply types are represented using the following Go types:
|
||||
//
|
||||
// Redis type Go type
|
||||
// error redis.Error
|
||||
// integer int64
|
||||
// simple string string
|
||||
// bulk string []byte or nil if value not present.
|
||||
// array []interface{} or nil if value not present.
|
||||
//
|
||||
// Use type assertions or the reply helper functions to convert from
|
||||
// interface{} to the specific Go type for the command result.
|
||||
//
|
||||
// Pipelining
|
||||
//
|
||||
// Connections support pipelining using the Send, Flush and Receive methods.
|
||||
//
|
||||
// Send(commandName string, args ...interface{}) error
|
||||
// Flush() error
|
||||
// Receive() (reply interface{}, err error)
|
||||
//
|
||||
// Send writes the command to the connection's output buffer. Flush flushes the
|
||||
// connection's output buffer to the server. Receive reads a single reply from
|
||||
// the server. The following example shows a simple pipeline.
|
||||
//
|
||||
// c.Send("SET", "foo", "bar")
|
||||
// c.Send("GET", "foo")
|
||||
// c.Flush()
|
||||
// c.Receive() // reply from SET
|
||||
// v, err = c.Receive() // reply from GET
|
||||
//
|
||||
// The Do method combines the functionality of the Send, Flush and Receive
|
||||
// methods. The Do method starts by writing the command and flushing the output
|
||||
// buffer. Next, the Do method receives all pending replies including the reply
|
||||
// for the command just sent by Do. If any of the received replies is an error,
|
||||
// then Do returns the error. If there are no errors, then Do returns the last
|
||||
// reply. If the command argument to the Do method is "", then the Do method
|
||||
// will flush the output buffer and receive pending replies without sending a
|
||||
// command.
|
||||
//
|
||||
// Use the Send and Do methods to implement pipelined transactions.
|
||||
//
|
||||
// c.Send("MULTI")
|
||||
// c.Send("INCR", "foo")
|
||||
// c.Send("INCR", "bar")
|
||||
// r, err := c.Do("EXEC")
|
||||
// fmt.Println(r) // prints [1, 1]
|
||||
//
|
||||
// Concurrency
|
||||
//
|
||||
// Connections support a single concurrent caller to the write methods (Send,
|
||||
// Flush) and a single concurrent caller to the read method (Receive). Because
|
||||
// Do method combines the functionality of Send, Flush and Receive, the Do
|
||||
// method cannot be called concurrently with the other methods.
|
||||
//
|
||||
// For full concurrent access to Redis, use the thread-safe Pool to get and
|
||||
// release connections from within a goroutine.
|
||||
//
|
||||
// Publish and Subscribe
|
||||
//
|
||||
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
|
||||
//
|
||||
// c.Send("SUBSCRIBE", "example")
|
||||
// c.Flush()
|
||||
// for {
|
||||
// reply, err := c.Receive()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // process pushed message
|
||||
// }
|
||||
//
|
||||
// The PubSubConn type wraps a Conn with convenience methods for implementing
|
||||
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
|
||||
// send and flush a subscription management command. The receive method
|
||||
// converts a pushed message to convenient types for use in a type switch.
|
||||
//
|
||||
// psc := PubSubConn{c}
|
||||
// psc.Subscribe("example")
|
||||
// for {
|
||||
// switch v := psc.Receive().(type) {
|
||||
// case redis.Message:
|
||||
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
|
||||
// case redis.Subscription:
|
||||
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
|
||||
// case error:
|
||||
// return v
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Reply Helpers
|
||||
//
|
||||
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
|
||||
// to a value of a specific type. To allow convenient wrapping of calls to the
|
||||
// connection Do and Receive methods, the functions take a second argument of
|
||||
// type error. If the error is non-nil, then the helper function returns the
|
||||
// error. If the error is nil, the function converts the reply to the specified
|
||||
// type:
|
||||
//
|
||||
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
|
||||
// if err != nil {
|
||||
// // handle error return from c.Do or type conversion error.
|
||||
// }
|
||||
//
|
||||
// The Scan function converts elements of a array reply to Go types:
|
||||
//
|
||||
// var value1 int
|
||||
// var value2 string
|
||||
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
package redis
|
|
@ -0,0 +1,117 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// NewLoggingConn returns a logging wrapper around a connection.
|
||||
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
|
||||
if prefix != "" {
|
||||
prefix = prefix + "."
|
||||
}
|
||||
return &loggingConn{conn, logger, prefix}
|
||||
}
|
||||
|
||||
type loggingConn struct {
|
||||
Conn
|
||||
logger *log.Logger
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (c *loggingConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
|
||||
c.logger.Output(2, buf.String())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
|
||||
const chop = 32
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
if len(v) > chop {
|
||||
fmt.Fprintf(buf, "%q...", v[:chop])
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%q", v)
|
||||
}
|
||||
case string:
|
||||
if len(v) > chop {
|
||||
fmt.Fprintf(buf, "%q...", v[:chop])
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%q", v)
|
||||
}
|
||||
case []interface{}:
|
||||
if len(v) == 0 {
|
||||
buf.WriteString("[]")
|
||||
} else {
|
||||
sep := "["
|
||||
fin := "]"
|
||||
if len(v) > chop {
|
||||
v = v[:chop]
|
||||
fin = "...]"
|
||||
}
|
||||
for _, vv := range v {
|
||||
buf.WriteString(sep)
|
||||
c.printValue(buf, vv)
|
||||
sep = ", "
|
||||
}
|
||||
buf.WriteString(fin)
|
||||
}
|
||||
default:
|
||||
fmt.Fprint(buf, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
|
||||
if method != "Receive" {
|
||||
buf.WriteString(commandName)
|
||||
for _, arg := range args {
|
||||
buf.WriteString(", ")
|
||||
c.printValue(&buf, arg)
|
||||
}
|
||||
}
|
||||
buf.WriteString(") -> (")
|
||||
if method != "Send" {
|
||||
c.printValue(&buf, reply)
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
fmt.Fprintf(&buf, "%v)", err)
|
||||
c.logger.Output(3, buf.String())
|
||||
}
|
||||
|
||||
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
|
||||
reply, err := c.Conn.Do(commandName, args...)
|
||||
c.print("Do", commandName, args, reply, err)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
|
||||
err := c.Conn.Send(commandName, args...)
|
||||
c.print("Send", commandName, args, nil, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *loggingConn) Receive() (interface{}, error) {
|
||||
reply, err := c.Conn.Receive()
|
||||
c.print("Receive", "", nil, reply, err)
|
||||
return reply, err
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var nowFunc = time.Now // for testing
|
||||
|
||||
// ErrPoolExhausted is returned from a pool connection method (Do, Send,
|
||||
// Receive, Flush, Err) when the maximum number of database connections in the
|
||||
// pool has been reached.
|
||||
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
|
||||
|
||||
var errPoolClosed = errors.New("redigo: connection pool closed")
|
||||
|
||||
// Pool maintains a pool of connections. The application calls the Get method
|
||||
// to get a connection from the pool and the connection's Close method to
|
||||
// return the connection's resources to the pool.
|
||||
//
|
||||
// The following example shows how to use a pool in a web application. The
|
||||
// application creates a pool at application startup and makes it available to
|
||||
// request handlers using a global variable.
|
||||
//
|
||||
// func newPool(server, password string) *redis.Pool {
|
||||
// return &redis.Pool{
|
||||
// MaxIdle: 3,
|
||||
// IdleTimeout: 240 * time.Second,
|
||||
// Dial: func () (redis.Conn, error) {
|
||||
// c, err := redis.Dial("tcp", server)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if _, err := c.Do("AUTH", password); err != nil {
|
||||
// c.Close()
|
||||
// return nil, err
|
||||
// }
|
||||
// return c, err
|
||||
// },
|
||||
// TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
// _, err := c.Do("PING")
|
||||
// return err
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// pool *redis.Pool
|
||||
// redisServer = flag.String("redisServer", ":6379", "")
|
||||
// redisPassword = flag.String("redisPassword", "", "")
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// flag.Parse()
|
||||
// pool = newPool(*redisServer, *redisPassword)
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// A request handler gets a connection from the pool and closes the connection
|
||||
// when the handler is done:
|
||||
//
|
||||
// func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||
// conn := pool.Get()
|
||||
// defer conn.Close()
|
||||
// ....
|
||||
// }
|
||||
//
|
||||
type Pool struct {
|
||||
|
||||
// Dial is an application supplied function for creating new connections.
|
||||
Dial func() (Conn, error)
|
||||
|
||||
// TestOnBorrow is an optional application supplied function for checking
|
||||
// the health of an idle connection before the connection is used again by
|
||||
// the application. Argument t is the time that the connection was returned
|
||||
// to the pool. If the function returns an error, then the connection is
|
||||
// closed.
|
||||
TestOnBorrow func(c Conn, t time.Time) error
|
||||
|
||||
// Maximum number of idle connections in the pool.
|
||||
MaxIdle int
|
||||
|
||||
// Maximum number of connections allocated by the pool at a given time.
|
||||
// When zero, there is no limit on the number of connections in the pool.
|
||||
MaxActive int
|
||||
|
||||
// Close connections after remaining idle for this duration. If the value
|
||||
// is zero, then idle connections are not closed. Applications should set
|
||||
// the timeout to a value less than the server's timeout.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// mu protects fields defined below.
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
active int
|
||||
|
||||
// Stack of idleConn with most recently used at the front.
|
||||
idle list.List
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
c Conn
|
||||
t time.Time
|
||||
}
|
||||
|
||||
// NewPool is a convenience function for initializing a pool.
|
||||
func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
|
||||
return &Pool{Dial: newFn, MaxIdle: maxIdle}
|
||||
}
|
||||
|
||||
// Get gets a connection. The application must close the returned connection.
|
||||
// The connection acquires an underlying connection on the first call to the
|
||||
// connection Do, Send, Receive, Flush or Err methods. An application can force
|
||||
// the connection to acquire an underlying connection without executing a Redis
|
||||
// command by calling the Err method.
|
||||
func (p *Pool) Get() Conn {
|
||||
return &pooledConnection{p: p}
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of active connections in the pool.
|
||||
func (p *Pool) ActiveCount() int {
|
||||
p.mu.Lock()
|
||||
active := p.active
|
||||
p.mu.Unlock()
|
||||
return active
|
||||
}
|
||||
|
||||
// Close releases the resources used by the pool.
|
||||
func (p *Pool) Close() error {
|
||||
p.mu.Lock()
|
||||
idle := p.idle
|
||||
p.idle.Init()
|
||||
p.closed = true
|
||||
p.active -= idle.Len()
|
||||
p.mu.Unlock()
|
||||
for e := idle.Front(); e != nil; e = e.Next() {
|
||||
e.Value.(idleConn).c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get prunes stale connections and returns a connection from the idle list or
|
||||
// creates a new connection.
|
||||
func (p *Pool) get() (Conn, error) {
|
||||
p.mu.Lock()
|
||||
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return nil, errors.New("redigo: get on closed pool")
|
||||
}
|
||||
|
||||
// Prune stale connections.
|
||||
|
||||
if timeout := p.IdleTimeout; timeout > 0 {
|
||||
for i, n := 0, p.idle.Len(); i < n; i++ {
|
||||
e := p.idle.Back()
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
ic := e.Value.(idleConn)
|
||||
if ic.t.Add(timeout).After(nowFunc()) {
|
||||
break
|
||||
}
|
||||
p.idle.Remove(e)
|
||||
p.active -= 1
|
||||
p.mu.Unlock()
|
||||
ic.c.Close()
|
||||
p.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
// Get idle connection.
|
||||
|
||||
for i, n := 0, p.idle.Len(); i < n; i++ {
|
||||
e := p.idle.Front()
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
ic := e.Value.(idleConn)
|
||||
p.idle.Remove(e)
|
||||
test := p.TestOnBorrow
|
||||
p.mu.Unlock()
|
||||
if test == nil || test(ic.c, ic.t) == nil {
|
||||
return ic.c, nil
|
||||
}
|
||||
ic.c.Close()
|
||||
p.mu.Lock()
|
||||
p.active -= 1
|
||||
}
|
||||
|
||||
if p.MaxActive > 0 && p.active >= p.MaxActive {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
// No idle connection, create new.
|
||||
|
||||
dial := p.Dial
|
||||
p.active += 1
|
||||
p.mu.Unlock()
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
p.mu.Lock()
|
||||
p.active -= 1
|
||||
p.mu.Unlock()
|
||||
c = nil
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (p *Pool) put(c Conn, forceClose bool) error {
|
||||
if c.Err() == nil && !forceClose {
|
||||
p.mu.Lock()
|
||||
if !p.closed {
|
||||
p.idle.PushFront(idleConn{t: nowFunc(), c: c})
|
||||
if p.idle.Len() > p.MaxIdle {
|
||||
c = p.idle.Remove(p.idle.Back()).(idleConn).c
|
||||
} else {
|
||||
c = nil
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
if c != nil {
|
||||
p.mu.Lock()
|
||||
p.active -= 1
|
||||
p.mu.Unlock()
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type pooledConnection struct {
|
||||
c Conn
|
||||
err error
|
||||
p *Pool
|
||||
state int
|
||||
}
|
||||
|
||||
func (c *pooledConnection) get() error {
|
||||
if c.err == nil && c.c == nil {
|
||||
c.c, c.err = c.p.get()
|
||||
}
|
||||
return c.err
|
||||
}
|
||||
|
||||
var (
|
||||
sentinel []byte
|
||||
sentinelOnce sync.Once
|
||||
)
|
||||
|
||||
func initSentinel() {
|
||||
p := make([]byte, 64)
|
||||
if _, err := rand.Read(p); err == nil {
|
||||
sentinel = p
|
||||
} else {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, "Oops, rand failed. Use time instead.")
|
||||
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
|
||||
sentinel = h.Sum(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Close() (err error) {
|
||||
if c.c != nil {
|
||||
if c.state&multiState != 0 {
|
||||
c.c.Send("DISCARD")
|
||||
c.state &^= (multiState | watchState)
|
||||
} else if c.state&watchState != 0 {
|
||||
c.c.Send("UNWATCH")
|
||||
c.state &^= watchState
|
||||
}
|
||||
if c.state&subscribeState != 0 {
|
||||
c.c.Send("UNSUBSCRIBE")
|
||||
c.c.Send("PUNSUBSCRIBE")
|
||||
// To detect the end of the message stream, ask the server to echo
|
||||
// a sentinel value and read until we see that value.
|
||||
sentinelOnce.Do(initSentinel)
|
||||
c.c.Send("ECHO", sentinel)
|
||||
c.c.Flush()
|
||||
for {
|
||||
p, err := c.c.Receive()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
|
||||
c.state &^= subscribeState
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.c.Do("")
|
||||
c.p.put(c.c, c.state != 0)
|
||||
c.c = nil
|
||||
c.err = errPoolClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Err() error {
|
||||
if err := c.get(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.c.Err()
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
if err := c.get(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ci := lookupCommandInfo(commandName)
|
||||
c.state = (c.state | ci.set) &^ ci.clear
|
||||
return c.c.Do(commandName, args...)
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Send(commandName string, args ...interface{}) error {
|
||||
if err := c.get(); err != nil {
|
||||
return err
|
||||
}
|
||||
ci := lookupCommandInfo(commandName)
|
||||
c.state = (c.state | ci.set) &^ ci.clear
|
||||
return c.c.Send(commandName, args...)
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Flush() error {
|
||||
if err := c.get(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.c.Flush()
|
||||
}
|
||||
|
||||
func (c *pooledConnection) Receive() (reply interface{}, err error) {
|
||||
if err := c.get(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.c.Receive()
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Subscription represents a subscribe or unsubscribe notification.
|
||||
type Subscription struct {
|
||||
|
||||
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
|
||||
Kind string
|
||||
|
||||
// The channel that was changed.
|
||||
Channel string
|
||||
|
||||
// The current number of subscriptions for connection.
|
||||
Count int
|
||||
}
|
||||
|
||||
// Message represents a message notification.
|
||||
type Message struct {
|
||||
|
||||
// The originating channel.
|
||||
Channel string
|
||||
|
||||
// The message data.
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// PMessage represents a pmessage notification.
|
||||
type PMessage struct {
|
||||
|
||||
// The matched pattern.
|
||||
Pattern string
|
||||
|
||||
// The originating channel.
|
||||
Channel string
|
||||
|
||||
// The message data.
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// PubSubConn wraps a Conn with convenience methods for subscribers.
|
||||
type PubSubConn struct {
|
||||
Conn Conn
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c PubSubConn) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// Subscribe subscribes the connection to the specified channels.
|
||||
func (c PubSubConn) Subscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("SUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// PSubscribe subscribes the connection to the given patterns.
|
||||
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("PSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes the connection from the given channels, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("UNSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("PUNSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// Receive returns a pushed message as a Subscription, Message, PMessage or
|
||||
// error. The return value is intended to be used directly in a type switch as
|
||||
// illustrated in the PubSubConn example.
|
||||
func (c PubSubConn) Receive() interface{} {
|
||||
reply, err := Values(c.Conn.Receive())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var kind string
|
||||
reply, err = Scan(reply, &kind)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "message":
|
||||
var m Message
|
||||
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return m
|
||||
case "pmessage":
|
||||
var pm PMessage
|
||||
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return pm
|
||||
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
|
||||
s := Subscription{Kind: kind}
|
||||
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
|
||||
return err
|
||||
}
|
||||
return s
|
||||
}
|
||||
return errors.New("redigo: unknown pubsub notification")
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
// Error represents an error returned in a command reply.
|
||||
type Error string
|
||||
|
||||
func (err Error) Error() string { return string(err) }
|
||||
|
||||
// Conn represents a connection to a Redis server.
|
||||
type Conn interface {
|
||||
// Close closes the connection.
|
||||
Close() error
|
||||
|
||||
// Err returns a non-nil value if the connection is broken. The returned
|
||||
// value is either the first non-nil value returned from the underlying
|
||||
// network connection or a protocol parsing error. Applications should
|
||||
// close broken connections.
|
||||
Err() error
|
||||
|
||||
// Do sends a command to the server and returns the received reply.
|
||||
Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||
|
||||
// Send writes the command to the client's output buffer.
|
||||
Send(commandName string, args ...interface{}) error
|
||||
|
||||
// Flush flushes the output buffer to the Redis server.
|
||||
Flush() error
|
||||
|
||||
// Receive receives a single reply from the Redis server
|
||||
Receive() (reply interface{}, err error)
|
||||
}
|
|
@ -0,0 +1,271 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ErrNil indicates that a reply value is nil.
|
||||
var ErrNil = errors.New("redigo: nil returned")
|
||||
|
||||
// Int is a helper that converts a command reply to an integer. If err is not
|
||||
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
|
||||
// reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer int(reply), nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int(reply interface{}, err error) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
x := int(reply)
|
||||
if int64(x) != reply {
|
||||
return 0, strconv.ErrRange
|
||||
}
|
||||
return x, nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseInt(string(reply), 10, 0)
|
||||
return int(n), err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Int, got type %T", reply)
|
||||
}
|
||||
|
||||
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int64(reply interface{}, err error) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
return reply, nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseInt(string(reply), 10, 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply)
|
||||
}
|
||||
|
||||
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
|
||||
|
||||
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Uint64(reply interface{}, err error) (uint64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
if reply < 0 {
|
||||
return 0, errNegativeInt
|
||||
}
|
||||
return uint64(reply), nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseUint(string(reply), 10, 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
|
||||
}
|
||||
|
||||
// Float64 is a helper that converts a command reply to 64 bit float. If err is
|
||||
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
|
||||
// the reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Float64(reply interface{}, err error) (float64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
n, err := strconv.ParseFloat(string(reply), 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Float64, got type %T", reply)
|
||||
}
|
||||
|
||||
// String is a helper that converts a command reply to a string. If err is not
|
||||
// equal to nil, then String returns "", err. Otherwise String converts the
|
||||
// reply to a string as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string string(reply), nil
|
||||
// simple string reply, nil
|
||||
// nil "", ErrNil
|
||||
// other "", error
|
||||
func String(reply interface{}, err error) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
return string(reply), nil
|
||||
case string:
|
||||
return reply, nil
|
||||
case nil:
|
||||
return "", ErrNil
|
||||
case Error:
|
||||
return "", reply
|
||||
}
|
||||
return "", fmt.Errorf("redigo: unexpected type for String, got type %T", reply)
|
||||
}
|
||||
|
||||
// Bytes is a helper that converts a command reply to a slice of bytes. If err
|
||||
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
|
||||
// the reply to a slice of bytes as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string reply, nil
|
||||
// simple string []byte(reply), nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Bytes(reply interface{}, err error) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
return reply, nil
|
||||
case string:
|
||||
return []byte(reply), nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
|
||||
}
|
||||
|
||||
// Bool is a helper that converts a command reply to a boolean. If err is not
|
||||
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
|
||||
// reply to boolean as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer value != 0, nil
|
||||
// bulk string strconv.ParseBool(reply)
|
||||
// nil false, ErrNil
|
||||
// other false, error
|
||||
func Bool(reply interface{}, err error) (bool, error) {
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
return reply != 0, nil
|
||||
case []byte:
|
||||
return strconv.ParseBool(string(reply))
|
||||
case nil:
|
||||
return false, ErrNil
|
||||
case Error:
|
||||
return false, reply
|
||||
}
|
||||
return false, fmt.Errorf("redigo: unexpected type for Bool, got type %T", reply)
|
||||
}
|
||||
|
||||
// MultiBulk is deprecated. Use Values.
|
||||
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
|
||||
|
||||
// Values is a helper that converts an array command reply to a []interface{}.
|
||||
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
|
||||
// converts the reply as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// array reply, nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Values(reply interface{}, err error) ([]interface{}, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []interface{}:
|
||||
return reply, nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply)
|
||||
}
|
||||
|
||||
// Strings is a helper that converts an array command reply to a []string. If
|
||||
// err is not equal to nil, then Strings returns nil, err. If one of the array
|
||||
// items is not a bulk string or nil, then Strings returns an error.
|
||||
func Strings(reply interface{}, err error) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []interface{}:
|
||||
result := make([]string, len(reply))
|
||||
for i := range reply {
|
||||
if reply[i] == nil {
|
||||
continue
|
||||
}
|
||||
p, ok := reply[i].([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
|
||||
}
|
||||
result[i] = string(p)
|
||||
}
|
||||
return result, nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply)
|
||||
}
|
|
@ -0,0 +1,513 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func ensureLen(d reflect.Value, n int) {
|
||||
if n > d.Cap() {
|
||||
d.Set(reflect.MakeSlice(d.Type(), n, n))
|
||||
} else {
|
||||
d.SetLen(n)
|
||||
}
|
||||
}
|
||||
|
||||
func cannotConvert(d reflect.Value, s interface{}) error {
|
||||
return fmt.Errorf("redigo: Scan cannot convert from %s to %s",
|
||||
reflect.TypeOf(s), d.Type())
|
||||
}
|
||||
|
||||
func convertAssignBytes(d reflect.Value, s []byte) (err error) {
|
||||
switch d.Type().Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
var x float64
|
||||
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
|
||||
d.SetFloat(x)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
var x int64
|
||||
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
|
||||
d.SetInt(x)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
var x uint64
|
||||
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
|
||||
d.SetUint(x)
|
||||
case reflect.Bool:
|
||||
var x bool
|
||||
x, err = strconv.ParseBool(string(s))
|
||||
d.SetBool(x)
|
||||
case reflect.String:
|
||||
d.SetString(string(s))
|
||||
case reflect.Slice:
|
||||
if d.Type().Elem().Kind() != reflect.Uint8 {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
d.SetBytes(s)
|
||||
}
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertAssignInt(d reflect.Value, s int64) (err error) {
|
||||
switch d.Type().Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
d.SetInt(s)
|
||||
if d.Int() != s {
|
||||
err = strconv.ErrRange
|
||||
d.SetInt(0)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if s < 0 {
|
||||
err = strconv.ErrRange
|
||||
} else {
|
||||
x := uint64(s)
|
||||
d.SetUint(x)
|
||||
if d.Uint() != x {
|
||||
err = strconv.ErrRange
|
||||
d.SetUint(0)
|
||||
}
|
||||
}
|
||||
case reflect.Bool:
|
||||
d.SetBool(s != 0)
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
|
||||
switch s := s.(type) {
|
||||
case []byte:
|
||||
err = convertAssignBytes(d, s)
|
||||
case int64:
|
||||
err = convertAssignInt(d, s)
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func convertAssignValues(d reflect.Value, s []interface{}) error {
|
||||
if d.Type().Kind() != reflect.Slice {
|
||||
return cannotConvert(d, s)
|
||||
}
|
||||
ensureLen(d, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertAssign(d interface{}, s interface{}) (err error) {
|
||||
// Handle the most common destination types using type switches and
|
||||
// fall back to reflection for all other types.
|
||||
switch s := s.(type) {
|
||||
case nil:
|
||||
// ingore
|
||||
case []byte:
|
||||
switch d := d.(type) {
|
||||
case *string:
|
||||
*d = string(s)
|
||||
case *int:
|
||||
*d, err = strconv.Atoi(string(s))
|
||||
case *bool:
|
||||
*d, err = strconv.ParseBool(string(s))
|
||||
case *[]byte:
|
||||
*d = s
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignBytes(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case int64:
|
||||
switch d := d.(type) {
|
||||
case *int:
|
||||
x := int(s)
|
||||
if int64(x) != s {
|
||||
err = strconv.ErrRange
|
||||
x = 0
|
||||
}
|
||||
*d = x
|
||||
case *bool:
|
||||
*d = s != 0
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignInt(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
switch d := d.(type) {
|
||||
case *[]interface{}:
|
||||
*d = s
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignValues(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case Error:
|
||||
err = s
|
||||
default:
|
||||
err = cannotConvert(reflect.ValueOf(d), s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Scan copies from src to the values pointed at by dest.
|
||||
//
|
||||
// The values pointed at by dest must be an integer, float, boolean, string,
|
||||
// []byte, interface{} or slices of these types. Scan uses the standard strconv
|
||||
// package to convert bulk strings to numeric and boolean types.
|
||||
//
|
||||
// If a dest value is nil, then the corresponding src value is skipped.
|
||||
//
|
||||
// If a src element is nil, then the corresponding dest value is not modified.
|
||||
//
|
||||
// To enable easy use of Scan in a loop, Scan returns the slice of src
|
||||
// following the copied values.
|
||||
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
|
||||
if len(src) < len(dest) {
|
||||
return nil, errors.New("redigo: Scan array short")
|
||||
}
|
||||
var err error
|
||||
for i, d := range dest {
|
||||
err = convertAssign(d, src[i])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return src[len(dest):], err
|
||||
}
|
||||
|
||||
type fieldSpec struct {
|
||||
name string
|
||||
index []int
|
||||
//omitEmpty bool
|
||||
}
|
||||
|
||||
type structSpec struct {
|
||||
m map[string]*fieldSpec
|
||||
l []*fieldSpec
|
||||
}
|
||||
|
||||
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
|
||||
return ss.m[string(name)]
|
||||
}
|
||||
|
||||
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
switch {
|
||||
case f.PkgPath != "":
|
||||
// Ignore unexported fields.
|
||||
case f.Anonymous:
|
||||
// TODO: Handle pointers. Requires change to decoder and
|
||||
// protection against infinite recursion.
|
||||
if f.Type.Kind() == reflect.Struct {
|
||||
compileStructSpec(f.Type, depth, append(index, i), ss)
|
||||
}
|
||||
default:
|
||||
fs := &fieldSpec{name: f.Name}
|
||||
tag := f.Tag.Get("redis")
|
||||
p := strings.Split(tag, ",")
|
||||
if len(p) > 0 {
|
||||
if p[0] == "-" {
|
||||
continue
|
||||
}
|
||||
if len(p[0]) > 0 {
|
||||
fs.name = p[0]
|
||||
}
|
||||
for _, s := range p[1:] {
|
||||
switch s {
|
||||
//case "omitempty":
|
||||
// fs.omitempty = true
|
||||
default:
|
||||
panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name()))
|
||||
}
|
||||
}
|
||||
}
|
||||
d, found := depth[fs.name]
|
||||
if !found {
|
||||
d = 1 << 30
|
||||
}
|
||||
switch {
|
||||
case len(index) == d:
|
||||
// At same depth, remove from result.
|
||||
delete(ss.m, fs.name)
|
||||
j := 0
|
||||
for i := 0; i < len(ss.l); i++ {
|
||||
if fs.name != ss.l[i].name {
|
||||
ss.l[j] = ss.l[i]
|
||||
j += 1
|
||||
}
|
||||
}
|
||||
ss.l = ss.l[:j]
|
||||
case len(index) < d:
|
||||
fs.index = make([]int, len(index)+1)
|
||||
copy(fs.index, index)
|
||||
fs.index[len(index)] = i
|
||||
depth[fs.name] = len(index)
|
||||
ss.m[fs.name] = fs
|
||||
ss.l = append(ss.l, fs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
structSpecMutex sync.RWMutex
|
||||
structSpecCache = make(map[reflect.Type]*structSpec)
|
||||
defaultFieldSpec = &fieldSpec{}
|
||||
)
|
||||
|
||||
func structSpecForType(t reflect.Type) *structSpec {
|
||||
|
||||
structSpecMutex.RLock()
|
||||
ss, found := structSpecCache[t]
|
||||
structSpecMutex.RUnlock()
|
||||
if found {
|
||||
return ss
|
||||
}
|
||||
|
||||
structSpecMutex.Lock()
|
||||
defer structSpecMutex.Unlock()
|
||||
ss, found = structSpecCache[t]
|
||||
if found {
|
||||
return ss
|
||||
}
|
||||
|
||||
ss = &structSpec{m: make(map[string]*fieldSpec)}
|
||||
compileStructSpec(t, make(map[string]int), nil, ss)
|
||||
structSpecCache[t] = ss
|
||||
return ss
|
||||
}
|
||||
|
||||
var errScanStructValue = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct")
|
||||
|
||||
// ScanStruct scans alternating names and values from src to a struct. The
|
||||
// HGETALL and CONFIG GET commands return replies in this format.
|
||||
//
|
||||
// ScanStruct uses exported field names to match values in the response. Use
|
||||
// 'redis' field tag to override the name:
|
||||
//
|
||||
// Field int `redis:"myName"`
|
||||
//
|
||||
// Fields with the tag redis:"-" are ignored.
|
||||
//
|
||||
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
|
||||
// standard strconv package to convert bulk string values to numeric and
|
||||
// boolean types.
|
||||
//
|
||||
// If a src element is nil, then the corresponding field is not modified.
|
||||
func ScanStruct(src []interface{}, dest interface{}) error {
|
||||
d := reflect.ValueOf(dest)
|
||||
if d.Kind() != reflect.Ptr || d.IsNil() {
|
||||
return errScanStructValue
|
||||
}
|
||||
d = d.Elem()
|
||||
if d.Kind() != reflect.Struct {
|
||||
return errScanStructValue
|
||||
}
|
||||
ss := structSpecForType(d.Type())
|
||||
|
||||
if len(src)%2 != 0 {
|
||||
return errors.New("redigo: ScanStruct expects even number of values in values")
|
||||
}
|
||||
|
||||
for i := 0; i < len(src); i += 2 {
|
||||
s := src[i+1]
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
name, ok := src[i].([]byte)
|
||||
if !ok {
|
||||
return errors.New("redigo: ScanStruct key not a bulk string value")
|
||||
}
|
||||
fs := ss.fieldSpec(name)
|
||||
if fs == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errScanSliceValue = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct")
|
||||
)
|
||||
|
||||
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
|
||||
// slice must be integer, float, boolean, string, struct or pointer to struct
|
||||
// values.
|
||||
//
|
||||
// Struct fields must be integer, float, boolean or string values. All struct
|
||||
// fields are used unless a subset is specified using fieldNames.
|
||||
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
|
||||
d := reflect.ValueOf(dest)
|
||||
if d.Kind() != reflect.Ptr || d.IsNil() {
|
||||
return errScanSliceValue
|
||||
}
|
||||
d = d.Elem()
|
||||
if d.Kind() != reflect.Slice {
|
||||
return errScanSliceValue
|
||||
}
|
||||
|
||||
isPtr := false
|
||||
t := d.Type().Elem()
|
||||
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
|
||||
isPtr = true
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
ensureLen(d, len(src))
|
||||
for i, s := range src {
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.Index(i), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ss := structSpecForType(t)
|
||||
fss := ss.l
|
||||
if len(fieldNames) > 0 {
|
||||
fss = make([]*fieldSpec, len(fieldNames))
|
||||
for i, name := range fieldNames {
|
||||
fss[i] = ss.m[name]
|
||||
if fss[i] == nil {
|
||||
return errors.New("redigo: ScanSlice bad field name " + name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(fss) == 0 {
|
||||
return errors.New("redigo: ScanSlice no struct fields")
|
||||
}
|
||||
|
||||
n := len(src) / len(fss)
|
||||
if n*len(fss) != len(src) {
|
||||
return errors.New("redigo: ScanSlice length not a multiple of struct field count")
|
||||
}
|
||||
|
||||
ensureLen(d, n)
|
||||
for i := 0; i < n; i++ {
|
||||
d := d.Index(i)
|
||||
if isPtr {
|
||||
if d.IsNil() {
|
||||
d.Set(reflect.New(t))
|
||||
}
|
||||
d = d.Elem()
|
||||
}
|
||||
for j, fs := range fss {
|
||||
s := src[i*len(fss)+j]
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Args is a helper for constructing command arguments from structured values.
|
||||
type Args []interface{}
|
||||
|
||||
// Add returns the result of appending value to args.
|
||||
func (args Args) Add(value ...interface{}) Args {
|
||||
return append(args, value...)
|
||||
}
|
||||
|
||||
// AddFlat returns the result of appending the flattened value of v to args.
|
||||
//
|
||||
// Maps are flattened by appending the alternating keys and map values to args.
|
||||
//
|
||||
// Slices are flattened by appending the slice elements to args.
|
||||
//
|
||||
// Structs are flattened by appending the alternating names and values of
|
||||
// exported fields to args. If v is a nil struct pointer, then nothing is
|
||||
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
|
||||
// for more information on the use of the 'redis' field tag.
|
||||
//
|
||||
// Other types are appended to args as is.
|
||||
func (args Args) AddFlat(v interface{}) Args {
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Struct:
|
||||
args = flattenStruct(args, rv)
|
||||
case reflect.Slice:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
args = append(args, rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if rv.Type().Elem().Kind() == reflect.Struct {
|
||||
if !rv.IsNil() {
|
||||
args = flattenStruct(args, rv.Elem())
|
||||
}
|
||||
} else {
|
||||
args = append(args, v)
|
||||
}
|
||||
default:
|
||||
args = append(args, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func flattenStruct(args Args, v reflect.Value) Args {
|
||||
ss := structSpecForType(v.Type())
|
||||
for _, fs := range ss.l {
|
||||
fv := v.FieldByIndex(fs.index)
|
||||
args = append(args, fs.name, fv.Interface())
|
||||
}
|
||||
return args
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Script encapsulates the source, hash and key count for a Lua script. See
|
||||
// http://redis.io/commands/eval for information on scripts in Redis.
|
||||
type Script struct {
|
||||
keyCount int
|
||||
src string
|
||||
hash string
|
||||
}
|
||||
|
||||
// NewScript returns a new script object. If keyCount is greater than or equal
|
||||
// to zero, then the count is automatically inserted in the EVAL command
|
||||
// argument list. If keyCount is less than zero, then the application supplies
|
||||
// the count as the first value in the keysAndArgs argument to the Do, Send and
|
||||
// SendHash methods.
|
||||
func NewScript(keyCount int, src string) *Script {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, src)
|
||||
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
|
||||
}
|
||||
|
||||
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
|
||||
var args []interface{}
|
||||
if s.keyCount < 0 {
|
||||
args = make([]interface{}, 1+len(keysAndArgs))
|
||||
args[0] = spec
|
||||
copy(args[1:], keysAndArgs)
|
||||
} else {
|
||||
args = make([]interface{}, 2+len(keysAndArgs))
|
||||
args[0] = spec
|
||||
args[1] = s.keyCount
|
||||
copy(args[2:], keysAndArgs)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// Do evalutes the script. Under the covers, Do optimistically evaluates the
|
||||
// script using the EVALSHA command. If the command fails because the script is
|
||||
// not loaded, then Do evaluates the script using the EVAL command (thus
|
||||
// causing the script to load).
|
||||
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
|
||||
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
|
||||
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
|
||||
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// SendHash evaluates the script without waiting for the reply. The script is
|
||||
// evaluated with the EVALSHA command. The application must ensure that the
|
||||
// script is loaded by a previous call to Send, Do or Load methods.
|
||||
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
|
||||
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
|
||||
}
|
||||
|
||||
// Send evaluates the script without waiting for the reply.
|
||||
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
|
||||
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
|
||||
}
|
||||
|
||||
// Load loads the script without evaluating it.
|
||||
func (s *Script) Load(c Conn) error {
|
||||
_, err := c.Do("SCRIPT", "LOAD", s.src)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
from redis.client import Redis, StrictRedis
|
||||
from redis.connection import (
|
||||
BlockingConnectionPool,
|
||||
ConnectionPool,
|
||||
Connection,
|
||||
UnixDomainSocketConnection
|
||||
)
|
||||
from redis.utils import from_url
|
||||
from redis.exceptions import (
|
||||
AuthenticationError,
|
||||
ConnectionError,
|
||||
BusyLoadingError,
|
||||
DataError,
|
||||
InvalidResponse,
|
||||
PubSubError,
|
||||
RedisError,
|
||||
ResponseError,
|
||||
WatchError,
|
||||
)
|
||||
|
||||
|
||||
__version__ = '2.7.6'
|
||||
VERSION = tuple(map(int, __version__.split('.')))
|
||||
|
||||
__all__ = [
|
||||
'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool',
|
||||
'Connection', 'UnixDomainSocketConnection',
|
||||
'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError',
|
||||
'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url',
|
||||
'BusyLoadingError'
|
||||
]
|
|
@ -0,0 +1,79 @@
|
|||
"""Internal module for Python 2 backwards compatibility."""
|
||||
import sys
|
||||
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
from urlparse import urlparse
|
||||
from itertools import imap, izip
|
||||
from string import letters as ascii_letters
|
||||
from Queue import Queue
|
||||
try:
|
||||
from cStringIO import StringIO as BytesIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO as BytesIO
|
||||
|
||||
iteritems = lambda x: x.iteritems()
|
||||
iterkeys = lambda x: x.iterkeys()
|
||||
itervalues = lambda x: x.itervalues()
|
||||
nativestr = lambda x: \
|
||||
x if isinstance(x, str) else x.encode('utf-8', 'replace')
|
||||
u = lambda x: x.decode()
|
||||
b = lambda x: x
|
||||
next = lambda x: x.next()
|
||||
byte_to_chr = lambda x: x
|
||||
unichr = unichr
|
||||
xrange = xrange
|
||||
basestring = basestring
|
||||
unicode = unicode
|
||||
bytes = str
|
||||
long = long
|
||||
else:
|
||||
from urllib.parse import urlparse
|
||||
from io import BytesIO
|
||||
from string import ascii_letters
|
||||
from queue import Queue
|
||||
|
||||
iteritems = lambda x: iter(x.items())
|
||||
iterkeys = lambda x: iter(x.keys())
|
||||
itervalues = lambda x: iter(x.values())
|
||||
byte_to_chr = lambda x: chr(x)
|
||||
nativestr = lambda x: \
|
||||
x if isinstance(x, str) else x.decode('utf-8', 'replace')
|
||||
u = lambda x: x
|
||||
b = lambda x: x.encode('iso-8859-1') if not isinstance(x, bytes) else x
|
||||
next = next
|
||||
unichr = chr
|
||||
imap = map
|
||||
izip = zip
|
||||
xrange = range
|
||||
basestring = str
|
||||
unicode = str
|
||||
bytes = bytes
|
||||
long = int
|
||||
|
||||
try: # Python 3
|
||||
from queue import LifoQueue, Empty, Full
|
||||
except ImportError:
|
||||
from Queue import Empty, Full
|
||||
try: # Python 2.6 - 2.7
|
||||
from Queue import LifoQueue
|
||||
except ImportError: # Python 2.5
|
||||
from Queue import Queue
|
||||
# From the Python 2.7 lib. Python 2.5 already extracted the core
|
||||
# methods to aid implementating different queue organisations.
|
||||
|
||||
class LifoQueue(Queue):
|
||||
"Override queue methods to implement a last-in first-out queue."
|
||||
|
||||
def _init(self, maxsize):
|
||||
self.maxsize = maxsize
|
||||
self.queue = []
|
||||
|
||||
def _qsize(self, len=len):
|
||||
return len(self.queue)
|
||||
|
||||
def _put(self, item):
|
||||
self.queue.append(item)
|
||||
|
||||
def _get(self):
|
||||
return self.queue.pop()
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,580 @@
|
|||
from itertools import chain
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
|
||||
BytesIO, nativestr, basestring,
|
||||
LifoQueue, Empty, Full)
|
||||
from redis.exceptions import (
|
||||
RedisError,
|
||||
ConnectionError,
|
||||
BusyLoadingError,
|
||||
ResponseError,
|
||||
InvalidResponse,
|
||||
AuthenticationError,
|
||||
NoScriptError,
|
||||
ExecAbortError,
|
||||
)
|
||||
from redis.utils import HIREDIS_AVAILABLE
|
||||
if HIREDIS_AVAILABLE:
|
||||
import hiredis
|
||||
|
||||
|
||||
SYM_STAR = b('*')
|
||||
SYM_DOLLAR = b('$')
|
||||
SYM_CRLF = b('\r\n')
|
||||
SYM_LF = b('\n')
|
||||
|
||||
|
||||
class PythonParser(object):
|
||||
"Plain Python parsing class"
|
||||
MAX_READ_LENGTH = 1000000
|
||||
encoding = None
|
||||
|
||||
EXCEPTION_CLASSES = {
|
||||
'ERR': ResponseError,
|
||||
'EXECABORT': ExecAbortError,
|
||||
'LOADING': BusyLoadingError,
|
||||
'NOSCRIPT': NoScriptError,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._fp = None
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.on_disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def on_connect(self, connection):
|
||||
"Called when the socket connects"
|
||||
self._fp = connection._sock.makefile('rb')
|
||||
if connection.decode_responses:
|
||||
self.encoding = connection.encoding
|
||||
|
||||
def on_disconnect(self):
|
||||
"Called when the socket disconnects"
|
||||
if self._fp is not None:
|
||||
self._fp.close()
|
||||
self._fp = None
|
||||
|
||||
def read(self, length=None):
|
||||
"""
|
||||
Read a line from the socket if no length is specified,
|
||||
otherwise read ``length`` bytes. Always strip away the newlines.
|
||||
"""
|
||||
try:
|
||||
if length is not None:
|
||||
bytes_left = length + 2 # read the line ending
|
||||
if length > self.MAX_READ_LENGTH:
|
||||
# apparently reading more than 1MB or so from a windows
|
||||
# socket can cause MemoryErrors. See:
|
||||
# https://github.com/andymccurdy/redis-py/issues/205
|
||||
# read smaller chunks at a time to work around this
|
||||
try:
|
||||
buf = BytesIO()
|
||||
while bytes_left > 0:
|
||||
read_len = min(bytes_left, self.MAX_READ_LENGTH)
|
||||
buf.write(self._fp.read(read_len))
|
||||
bytes_left -= read_len
|
||||
buf.seek(0)
|
||||
return buf.read(length)
|
||||
finally:
|
||||
buf.close()
|
||||
return self._fp.read(bytes_left)[:-2]
|
||||
|
||||
# no length, read a full line
|
||||
return self._fp.readline()[:-2]
|
||||
except (socket.error, socket.timeout):
|
||||
e = sys.exc_info()[1]
|
||||
raise ConnectionError("Error while reading from socket: %s" %
|
||||
(e.args,))
|
||||
|
||||
def parse_error(self, response):
|
||||
"Parse an error response"
|
||||
error_code = response.split(' ')[0]
|
||||
if error_code in self.EXCEPTION_CLASSES:
|
||||
response = response[len(error_code) + 1:]
|
||||
return self.EXCEPTION_CLASSES[error_code](response)
|
||||
return ResponseError(response)
|
||||
|
||||
def read_response(self):
|
||||
response = self.read()
|
||||
if not response:
|
||||
raise ConnectionError("Socket closed on remote end")
|
||||
|
||||
byte, response = byte_to_chr(response[0]), response[1:]
|
||||
|
||||
if byte not in ('-', '+', ':', '$', '*'):
|
||||
raise InvalidResponse("Protocol Error")
|
||||
|
||||
# server returned an error
|
||||
if byte == '-':
|
||||
response = nativestr(response)
|
||||
error = self.parse_error(response)
|
||||
# if the error is a ConnectionError, raise immediately so the user
|
||||
# is notified
|
||||
if isinstance(error, ConnectionError):
|
||||
raise error
|
||||
# otherwise, we're dealing with a ResponseError that might belong
|
||||
# inside a pipeline response. the connection's read_response()
|
||||
# and/or the pipeline's execute() will raise this error if
|
||||
# necessary, so just return the exception instance here.
|
||||
return error
|
||||
# single value
|
||||
elif byte == '+':
|
||||
pass
|
||||
# int value
|
||||
elif byte == ':':
|
||||
response = long(response)
|
||||
# bulk response
|
||||
elif byte == '$':
|
||||
length = int(response)
|
||||
if length == -1:
|
||||
return None
|
||||
response = self.read(length)
|
||||
# multi-bulk response
|
||||
elif byte == '*':
|
||||
length = int(response)
|
||||
if length == -1:
|
||||
return None
|
||||
response = [self.read_response() for i in xrange(length)]
|
||||
if isinstance(response, bytes) and self.encoding:
|
||||
response = response.decode(self.encoding)
|
||||
return response
|
||||
|
||||
|
||||
class HiredisParser(object):
|
||||
"Parser class for connections using Hiredis"
|
||||
def __init__(self):
|
||||
if not HIREDIS_AVAILABLE:
|
||||
raise RedisError("Hiredis is not installed")
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.on_disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def on_connect(self, connection):
|
||||
self._sock = connection._sock
|
||||
kwargs = {
|
||||
'protocolError': InvalidResponse,
|
||||
'replyError': ResponseError,
|
||||
}
|
||||
if connection.decode_responses:
|
||||
kwargs['encoding'] = connection.encoding
|
||||
self._reader = hiredis.Reader(**kwargs)
|
||||
|
||||
def on_disconnect(self):
|
||||
self._sock = None
|
||||
self._reader = None
|
||||
|
||||
def read_response(self):
|
||||
if not self._reader:
|
||||
raise ConnectionError("Socket closed on remote end")
|
||||
response = self._reader.gets()
|
||||
while response is False:
|
||||
try:
|
||||
buffer = self._sock.recv(4096)
|
||||
except (socket.error, socket.timeout):
|
||||
e = sys.exc_info()[1]
|
||||
raise ConnectionError("Error while reading from socket: %s" %
|
||||
(e.args,))
|
||||
if not buffer:
|
||||
raise ConnectionError("Socket closed on remote end")
|
||||
self._reader.feed(buffer)
|
||||
# proactively, but not conclusively, check if more data is in the
|
||||
# buffer. if the data received doesn't end with \n, there's more.
|
||||
if not buffer.endswith(SYM_LF):
|
||||
continue
|
||||
response = self._reader.gets()
|
||||
return response
|
||||
|
||||
if HIREDIS_AVAILABLE:
|
||||
DefaultParser = HiredisParser
|
||||
else:
|
||||
DefaultParser = PythonParser
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"Manages TCP communication to and from a Redis server"
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None,
|
||||
socket_timeout=None, encoding='utf-8',
|
||||
encoding_errors='strict', decode_responses=False,
|
||||
parser_class=DefaultParser):
|
||||
self.pid = os.getpid()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.db = db
|
||||
self.password = password
|
||||
self.socket_timeout = socket_timeout
|
||||
self.encoding = encoding
|
||||
self.encoding_errors = encoding_errors
|
||||
self.decode_responses = decode_responses
|
||||
self._sock = None
|
||||
self._parser = parser_class()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
"Connects to the Redis server if not already connected"
|
||||
if self._sock:
|
||||
return
|
||||
try:
|
||||
sock = self._connect()
|
||||
except socket.error:
|
||||
e = sys.exc_info()[1]
|
||||
raise ConnectionError(self._error_message(e))
|
||||
|
||||
self._sock = sock
|
||||
self.on_connect()
|
||||
|
||||
def _connect(self):
|
||||
"Create a TCP socket connection"
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.socket_timeout)
|
||||
sock.connect((self.host, self.port))
|
||||
return sock
|
||||
|
||||
def _error_message(self, exception):
|
||||
# args for socket.error can either be (errno, "message")
|
||||
# or just "message"
|
||||
if len(exception.args) == 1:
|
||||
return "Error connecting to %s:%s. %s." % \
|
||||
(self.host, self.port, exception.args[0])
|
||||
else:
|
||||
return "Error %s connecting %s:%s. %s." % \
|
||||
(exception.args[0], self.host, self.port, exception.args[1])
|
||||
|
||||
def on_connect(self):
|
||||
"Initialize the connection, authenticate and select a database"
|
||||
self._parser.on_connect(self)
|
||||
|
||||
# if a password is specified, authenticate
|
||||
if self.password:
|
||||
self.send_command('AUTH', self.password)
|
||||
if nativestr(self.read_response()) != 'OK':
|
||||
raise AuthenticationError('Invalid Password')
|
||||
|
||||
# if a database is specified, switch to it
|
||||
if self.db:
|
||||
self.send_command('SELECT', self.db)
|
||||
if nativestr(self.read_response()) != 'OK':
|
||||
raise ConnectionError('Invalid Database')
|
||||
|
||||
def disconnect(self):
|
||||
"Disconnects from the Redis server"
|
||||
self._parser.on_disconnect()
|
||||
if self._sock is None:
|
||||
return
|
||||
try:
|
||||
self._sock.close()
|
||||
except socket.error:
|
||||
pass
|
||||
self._sock = None
|
||||
|
||||
def send_packed_command(self, command):
|
||||
"Send an already packed command to the Redis server"
|
||||
if not self._sock:
|
||||
self.connect()
|
||||
try:
|
||||
self._sock.sendall(command)
|
||||
except socket.error:
|
||||
e = sys.exc_info()[1]
|
||||
self.disconnect()
|
||||
if len(e.args) == 1:
|
||||
_errno, errmsg = 'UNKNOWN', e.args[0]
|
||||
else:
|
||||
_errno, errmsg = e.args
|
||||
raise ConnectionError("Error %s while writing to socket. %s." %
|
||||
(_errno, errmsg))
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
raise
|
||||
|
||||
def send_command(self, *args):
|
||||
"Pack and send a command to the Redis server"
|
||||
self.send_packed_command(self.pack_command(*args))
|
||||
|
||||
def read_response(self):
|
||||
"Read the response from a previously sent command"
|
||||
try:
|
||||
response = self._parser.read_response()
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
raise
|
||||
if isinstance(response, ResponseError):
|
||||
raise response
|
||||
return response
|
||||
|
||||
def encode(self, value):
|
||||
"Return a bytestring representation of the value"
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
value = repr(value)
|
||||
if not isinstance(value, basestring):
|
||||
value = str(value)
|
||||
if isinstance(value, unicode):
|
||||
value = value.encode(self.encoding, self.encoding_errors)
|
||||
return value
|
||||
|
||||
def pack_command(self, *args):
|
||||
"Pack a series of arguments into a value Redis command"
|
||||
output = SYM_STAR + b(str(len(args))) + SYM_CRLF
|
||||
for enc_value in imap(self.encode, args):
|
||||
output += SYM_DOLLAR
|
||||
output += b(str(len(enc_value)))
|
||||
output += SYM_CRLF
|
||||
output += enc_value
|
||||
output += SYM_CRLF
|
||||
return output
|
||||
|
||||
|
||||
class UnixDomainSocketConnection(Connection):
|
||||
def __init__(self, path='', db=0, password=None,
|
||||
socket_timeout=None, encoding='utf-8',
|
||||
encoding_errors='strict', decode_responses=False,
|
||||
parser_class=DefaultParser):
|
||||
self.pid = os.getpid()
|
||||
self.path = path
|
||||
self.db = db
|
||||
self.password = password
|
||||
self.socket_timeout = socket_timeout
|
||||
self.encoding = encoding
|
||||
self.encoding_errors = encoding_errors
|
||||
self.decode_responses = decode_responses
|
||||
self._sock = None
|
||||
self._parser = parser_class()
|
||||
|
||||
def _connect(self):
|
||||
"Create a Unix domain socket connection"
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.socket_timeout)
|
||||
sock.connect(self.path)
|
||||
return sock
|
||||
|
||||
def _error_message(self, exception):
|
||||
# args for socket.error can either be (errno, "message")
|
||||
# or just "message"
|
||||
if len(exception.args) == 1:
|
||||
return "Error connecting to unix socket: %s. %s." % \
|
||||
(self.path, exception.args[0])
|
||||
else:
|
||||
return "Error %s connecting to unix socket: %s. %s." % \
|
||||
(exception.args[0], self.path, exception.args[1])
|
||||
|
||||
|
||||
# TODO: add ability to block waiting on a connection to be released
|
||||
class ConnectionPool(object):
|
||||
"Generic connection pool"
|
||||
def __init__(self, connection_class=Connection, max_connections=None,
|
||||
**connection_kwargs):
|
||||
self.pid = os.getpid()
|
||||
self.connection_class = connection_class
|
||||
self.connection_kwargs = connection_kwargs
|
||||
self.max_connections = max_connections or 2 ** 31
|
||||
self._created_connections = 0
|
||||
self._available_connections = []
|
||||
self._in_use_connections = set()
|
||||
|
||||
def _checkpid(self):
|
||||
if self.pid != os.getpid():
|
||||
self.disconnect()
|
||||
self.__init__(self.connection_class, self.max_connections,
|
||||
**self.connection_kwargs)
|
||||
|
||||
def get_connection(self, command_name, *keys, **options):
|
||||
"Get a connection from the pool"
|
||||
self._checkpid()
|
||||
try:
|
||||
connection = self._available_connections.pop()
|
||||
except IndexError:
|
||||
connection = self.make_connection()
|
||||
self._in_use_connections.add(connection)
|
||||
return connection
|
||||
|
||||
def make_connection(self):
|
||||
"Create a new connection"
|
||||
if self._created_connections >= self.max_connections:
|
||||
raise ConnectionError("Too many connections")
|
||||
self._created_connections += 1
|
||||
return self.connection_class(**self.connection_kwargs)
|
||||
|
||||
def release(self, connection):
|
||||
"Releases the connection back to the pool"
|
||||
self._checkpid()
|
||||
if connection.pid == self.pid:
|
||||
self._in_use_connections.remove(connection)
|
||||
self._available_connections.append(connection)
|
||||
|
||||
def disconnect(self):
|
||||
"Disconnects all connections in the pool"
|
||||
all_conns = chain(self._available_connections,
|
||||
self._in_use_connections)
|
||||
for connection in all_conns:
|
||||
connection.disconnect()
|
||||
|
||||
|
||||
class BlockingConnectionPool(object):
|
||||
"""
|
||||
Thread-safe blocking connection pool::
|
||||
|
||||
>>> from redis.client import Redis
|
||||
>>> client = Redis(connection_pool=BlockingConnectionPool())
|
||||
|
||||
It performs the same function as the default
|
||||
``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
|
||||
it maintains a pool of reusable connections that can be shared by
|
||||
multiple redis clients (safely across threads if required).
|
||||
|
||||
The difference is that, in the event that a client tries to get a
|
||||
connection from the pool when all of connections are in use, rather than
|
||||
raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
|
||||
``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
|
||||
makes the client wait ("blocks") for a specified number of seconds until
|
||||
a connection becomes available.
|
||||
|
||||
Use ``max_connections`` to increase / decrease the pool size::
|
||||
|
||||
>>> pool = BlockingConnectionPool(max_connections=10)
|
||||
|
||||
Use ``timeout`` to tell it either how many seconds to wait for a connection
|
||||
to become available, or to block forever:
|
||||
|
||||
# Block forever.
|
||||
>>> pool = BlockingConnectionPool(timeout=None)
|
||||
|
||||
# Raise a ``ConnectionError`` after five seconds if a connection is
|
||||
# not available.
|
||||
>>> pool = BlockingConnectionPool(timeout=5)
|
||||
"""
|
||||
def __init__(self, max_connections=50, timeout=20, connection_class=None,
|
||||
queue_class=None, **connection_kwargs):
|
||||
"Compose and assign values."
|
||||
# Compose.
|
||||
if connection_class is None:
|
||||
connection_class = Connection
|
||||
if queue_class is None:
|
||||
queue_class = LifoQueue
|
||||
|
||||
# Assign.
|
||||
self.connection_class = connection_class
|
||||
self.connection_kwargs = connection_kwargs
|
||||
self.queue_class = queue_class
|
||||
self.max_connections = max_connections
|
||||
self.timeout = timeout
|
||||
|
||||
# Validate the ``max_connections``. With the "fill up the queue"
|
||||
# algorithm we use, it must be a positive integer.
|
||||
is_valid = isinstance(max_connections, int) and max_connections > 0
|
||||
if not is_valid:
|
||||
raise ValueError('``max_connections`` must be a positive integer')
|
||||
|
||||
# Get the current process id, so we can disconnect and reinstantiate if
|
||||
# it changes.
|
||||
self.pid = os.getpid()
|
||||
|
||||
# Create and fill up a thread safe queue with ``None`` values.
|
||||
self.pool = self.queue_class(max_connections)
|
||||
while True:
|
||||
try:
|
||||
self.pool.put_nowait(None)
|
||||
except Full:
|
||||
break
|
||||
|
||||
# Keep a list of actual connection instances so that we can
|
||||
# disconnect them later.
|
||||
self._connections = []
|
||||
|
||||
def _checkpid(self):
|
||||
"""
|
||||
Check the current process id. If it has changed, disconnect and
|
||||
re-instantiate this connection pool instance.
|
||||
"""
|
||||
# Get the current process id.
|
||||
pid = os.getpid()
|
||||
|
||||
# If it hasn't changed since we were instantiated, then we're fine, so
|
||||
# just exit, remaining connected.
|
||||
if self.pid == pid:
|
||||
return
|
||||
|
||||
# If it has changed, then disconnect and re-instantiate.
|
||||
self.disconnect()
|
||||
self.reinstantiate()
|
||||
|
||||
def make_connection(self):
|
||||
"Make a fresh connection."
|
||||
connection = self.connection_class(**self.connection_kwargs)
|
||||
self._connections.append(connection)
|
||||
return connection
|
||||
|
||||
def get_connection(self, command_name, *keys, **options):
|
||||
"""
|
||||
Get a connection, blocking for ``self.timeout`` until a connection
|
||||
is available from the pool.
|
||||
|
||||
If the connection returned is ``None`` then creates a new connection.
|
||||
Because we use a last-in first-out queue, the existing connections
|
||||
(having been returned to the pool after the initial ``None`` values
|
||||
were added) will be returned before ``None`` values. This means we only
|
||||
create new connections when we need to, i.e.: the actual number of
|
||||
connections will only increase in response to demand.
|
||||
"""
|
||||
# Make sure we haven't changed process.
|
||||
self._checkpid()
|
||||
|
||||
# Try and get a connection from the pool. If one isn't available within
|
||||
# self.timeout then raise a ``ConnectionError``.
|
||||
connection = None
|
||||
try:
|
||||
connection = self.pool.get(block=True, timeout=self.timeout)
|
||||
except Empty:
|
||||
# Note that this is not caught by the redis client and will be
|
||||
# raised unless handled by application code. If you want never to
|
||||
raise ConnectionError("No connection available.")
|
||||
|
||||
# If the ``connection`` is actually ``None`` then that's a cue to make
|
||||
# a new connection to add to the pool.
|
||||
if connection is None:
|
||||
connection = self.make_connection()
|
||||
|
||||
return connection
|
||||
|
||||
def release(self, connection):
|
||||
"Releases the connection back to the pool."
|
||||
# Make sure we haven't changed process.
|
||||
self._checkpid()
|
||||
|
||||
# Put the connection back into the pool.
|
||||
try:
|
||||
self.pool.put_nowait(connection)
|
||||
except Full:
|
||||
# This shouldn't normally happen but might perhaps happen after a
|
||||
# reinstantiation. So, we can handle the exception by not putting
|
||||
# the connection back on the pool, because we definitely do not
|
||||
# want to reuse it.
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
"Disconnects all connections in the pool."
|
||||
for connection in self._connections:
|
||||
connection.disconnect()
|
||||
|
||||
def reinstantiate(self):
|
||||
"""
|
||||
Reinstatiate this instance within a new process with a new connection
|
||||
pool set.
|
||||
"""
|
||||
self.__init__(max_connections=self.max_connections,
|
||||
timeout=self.timeout,
|
||||
connection_class=self.connection_class,
|
||||
queue_class=self.queue_class, **self.connection_kwargs)
|
|
@ -0,0 +1,49 @@
|
|||
"Core exceptions raised by the Redis client"
|
||||
|
||||
|
||||
class RedisError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class ServerError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionError(ServerError):
|
||||
pass
|
||||
|
||||
|
||||
class BusyLoadingError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidResponse(ServerError):
|
||||
pass
|
||||
|
||||
|
||||
class ResponseError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class DataError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class PubSubError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class WatchError(RedisError):
|
||||
pass
|
||||
|
||||
|
||||
class NoScriptError(ResponseError):
|
||||
pass
|
||||
|
||||
|
||||
class ExecAbortError(ResponseError):
|
||||
pass
|
|
@ -0,0 +1,16 @@
|
|||
try:
|
||||
import hiredis
|
||||
HIREDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
HIREDIS_AVAILABLE = False
|
||||
|
||||
|
||||
def from_url(url, db=None, **kwargs):
|
||||
"""
|
||||
Returns an active Redis client generated from the given database URL.
|
||||
|
||||
Will attempt to extract the database id from the path url fragment, if
|
||||
none is provided.
|
||||
"""
|
||||
from redis.client import Redis
|
||||
return Redis.from_url(url, db, **kwargs)
|
|
@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env python
|
||||
import os
|
||||
import sys
|
||||
|
||||
from redis import __version__
|
||||
|
||||
try:
|
||||
from setuptools import setup
|
||||
from setuptools.command.test import test as TestCommand
|
||||
|
||||
class PyTest(TestCommand):
|
||||
def finalize_options(self):
|
||||
TestCommand.finalize_options(self)
|
||||
self.test_args = []
|
||||
self.test_suite = True
|
||||
|
||||
def run_tests(self):
|
||||
# import here, because outside the eggs aren't loaded
|
||||
import pytest
|
||||
errno = pytest.main(self.test_args)
|
||||
sys.exit(errno)
|
||||
|
||||
except ImportError:
|
||||
|
||||
from distutils.core import setup
|
||||
PyTest = lambda x: x
|
||||
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'README.rst'))
|
||||
long_description = f.read()
|
||||
f.close()
|
||||
|
||||
setup(
|
||||
name='redis',
|
||||
version=__version__,
|
||||
description='Python client for Redis key-value store',
|
||||
long_description=long_description,
|
||||
url='http://github.com/andymccurdy/redis-py',
|
||||
author='Andy McCurdy',
|
||||
author_email='sedrik@gmail.com',
|
||||
maintainer='Andy McCurdy',
|
||||
maintainer_email='sedrik@gmail.com',
|
||||
keywords=['Redis', 'key-value store'],
|
||||
license='MIT',
|
||||
packages=['redis'],
|
||||
tests_require=['pytest>=2.5.0'],
|
||||
cmdclass={'test': PyTest},
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Environment :: Console',
|
||||
'Intended Audience :: Developers',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python',
|
||||
'Programming Language :: Python :: 2.6',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.2',
|
||||
'Programming Language :: Python :: 3.3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
]
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
import redis
|
||||
|
||||
from distutils.version import StrictVersion
|
||||
|
||||
|
||||
_REDIS_VERSIONS = {}
|
||||
|
||||
|
||||
def get_version(**kwargs):
|
||||
params = {'host': 'localhost', 'port': 6379, 'db': 9}
|
||||
params.update(kwargs)
|
||||
key = '%s:%s' % (params['host'], params['port'])
|
||||
if key not in _REDIS_VERSIONS:
|
||||
client = redis.Redis(**params)
|
||||
_REDIS_VERSIONS[key] = client.info()['redis_version']
|
||||
client.connection_pool.disconnect()
|
||||
return _REDIS_VERSIONS[key]
|
||||
|
||||
|
||||
def _get_client(cls, request=None, **kwargs):
|
||||
params = {'host': 'localhost', 'port': 6379, 'db': 9}
|
||||
params.update(kwargs)
|
||||
client = cls(**params)
|
||||
client.flushdb()
|
||||
if request:
|
||||
def teardown():
|
||||
client.flushdb()
|
||||
client.connection_pool.disconnect()
|
||||
request.addfinalizer(teardown)
|
||||
return client
|
||||
|
||||
|
||||
def skip_if_server_version_lt(min_version):
|
||||
check = StrictVersion(get_version()) < StrictVersion(min_version)
|
||||
return pytest.mark.skipif(check, reason="")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def r(request, **kwargs):
|
||||
return _get_client(redis.Redis, request, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sr(request, **kwargs):
|
||||
return _get_client(redis.StrictRedis, request, **kwargs)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,402 @@
|
|||
from __future__ import with_statement
|
||||
import os
|
||||
import pytest
|
||||
import redis
|
||||
import time
|
||||
import re
|
||||
|
||||
from threading import Thread
|
||||
from redis.connection import ssl_available
|
||||
from .conftest import skip_if_server_version_lt
|
||||
|
||||
|
||||
class DummyConnection(object):
|
||||
description_format = "DummyConnection<>"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.pid = os.getpid()
|
||||
|
||||
|
||||
class TestConnectionPool(object):
|
||||
def get_pool(self, connection_kwargs=None, max_connections=None,
|
||||
connection_class=DummyConnection):
|
||||
connection_kwargs = connection_kwargs or {}
|
||||
pool = redis.ConnectionPool(
|
||||
connection_class=connection_class,
|
||||
max_connections=max_connections,
|
||||
**connection_kwargs)
|
||||
return pool
|
||||
|
||||
def test_connection_creation(self):
|
||||
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
|
||||
pool = self.get_pool(connection_kwargs=connection_kwargs)
|
||||
connection = pool.get_connection('_')
|
||||
assert isinstance(connection, DummyConnection)
|
||||
assert connection.kwargs == connection_kwargs
|
||||
|
||||
def test_multiple_connections(self):
|
||||
pool = self.get_pool()
|
||||
c1 = pool.get_connection('_')
|
||||
c2 = pool.get_connection('_')
|
||||
assert c1 != c2
|
||||
|
||||
def test_max_connections(self):
|
||||
pool = self.get_pool(max_connections=2)
|
||||
pool.get_connection('_')
|
||||
pool.get_connection('_')
|
||||
with pytest.raises(redis.ConnectionError):
|
||||
pool.get_connection('_')
|
||||
|
||||
def test_reuse_previously_released_connection(self):
|
||||
pool = self.get_pool()
|
||||
c1 = pool.get_connection('_')
|
||||
pool.release(c1)
|
||||
c2 = pool.get_connection('_')
|
||||
assert c1 == c2
|
||||
|
||||
def test_repr_contains_db_info_tcp(self):
|
||||
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
|
||||
pool = self.get_pool(connection_kwargs=connection_kwargs,
|
||||
connection_class=redis.Connection)
|
||||
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
|
||||
assert repr(pool) == expected
|
||||
|
||||
def test_repr_contains_db_info_unix(self):
|
||||
connection_kwargs = {'path': '/abc', 'db': 1}
|
||||
pool = self.get_pool(connection_kwargs=connection_kwargs,
|
||||
connection_class=redis.UnixDomainSocketConnection)
|
||||
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
|
||||
assert repr(pool) == expected
|
||||
|
||||
|
||||
class TestBlockingConnectionPool(object):
|
||||
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
|
||||
connection_kwargs = connection_kwargs or {}
|
||||
pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
**connection_kwargs)
|
||||
return pool
|
||||
|
||||
def test_connection_creation(self):
|
||||
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
|
||||
pool = self.get_pool(connection_kwargs=connection_kwargs)
|
||||
connection = pool.get_connection('_')
|
||||
assert isinstance(connection, DummyConnection)
|
||||
assert connection.kwargs == connection_kwargs
|
||||
|
||||
def test_multiple_connections(self):
|
||||
pool = self.get_pool()
|
||||
c1 = pool.get_connection('_')
|
||||
c2 = pool.get_connection('_')
|
||||
assert c1 != c2
|
||||
|
||||
def test_connection_pool_blocks_until_timeout(self):
|
||||
"When out of connections, block for timeout seconds, then raise"
|
||||
pool = self.get_pool(max_connections=1, timeout=0.1)
|
||||
pool.get_connection('_')
|
||||
|
||||
start = time.time()
|
||||
with pytest.raises(redis.ConnectionError):
|
||||
pool.get_connection('_')
|
||||
# we should have waited at least 0.1 seconds
|
||||
assert time.time() - start >= 0.1
|
||||
|
||||
def connection_pool_blocks_until_another_connection_released(self):
|
||||
"""
|
||||
When out of connections, block until another connection is released
|
||||
to the pool
|
||||
"""
|
||||
pool = self.get_pool(max_connections=1, timeout=2)
|
||||
c1 = pool.get_connection('_')
|
||||
|
||||
def target():
|
||||
time.sleep(0.1)
|
||||
pool.release(c1)
|
||||
|
||||
Thread(target=target).start()
|
||||
start = time.time()
|
||||
pool.get_connection('_')
|
||||
assert time.time() - start >= 0.1
|
||||
|
||||
def test_reuse_previously_released_connection(self):
|
||||
pool = self.get_pool()
|
||||
c1 = pool.get_connection('_')
|
||||
pool.release(c1)
|
||||
c2 = pool.get_connection('_')
|
||||
assert c1 == c2
|
||||
|
||||
def test_repr_contains_db_info_tcp(self):
|
||||
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
|
||||
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
|
||||
assert repr(pool) == expected
|
||||
|
||||
def test_repr_contains_db_info_unix(self):
|
||||
pool = redis.ConnectionPool(
|
||||
connection_class=redis.UnixDomainSocketConnection,
|
||||
path='abc',
|
||||
db=0,
|
||||
)
|
||||
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
|
||||
assert repr(pool) == expected
|
||||
|
||||
|
||||
class TestConnectionPoolURLParsing(object):
|
||||
def test_defaults(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_hostname(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://myhost')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'myhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_port(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost:6380')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6380,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_password(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': 'mypassword',
|
||||
}
|
||||
|
||||
def test_db_as_argument(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost', db='1')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 1,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_db_in_path(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 2,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_db_in_querystring(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3',
|
||||
db='1')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 3,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_extra_querystring_options(self):
|
||||
pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2')
|
||||
assert pool.connection_class == redis.Connection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
'a': '1',
|
||||
'b': '2'
|
||||
}
|
||||
|
||||
def test_calling_from_subclass_returns_correct_instance(self):
|
||||
pool = redis.BlockingConnectionPool.from_url('redis://localhost')
|
||||
assert isinstance(pool, redis.BlockingConnectionPool)
|
||||
|
||||
def test_client_creates_connection_pool(self):
|
||||
r = redis.StrictRedis.from_url('redis://myhost')
|
||||
assert r.connection_pool.connection_class == redis.Connection
|
||||
assert r.connection_pool.connection_kwargs == {
|
||||
'host': 'myhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
|
||||
class TestConnectionPoolUnixSocketURLParsing(object):
|
||||
def test_defaults(self):
|
||||
pool = redis.ConnectionPool.from_url('unix:///socket')
|
||||
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'path': '/socket',
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_password(self):
|
||||
pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket')
|
||||
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'path': '/socket',
|
||||
'db': 0,
|
||||
'password': 'mypassword',
|
||||
}
|
||||
|
||||
def test_db_as_argument(self):
|
||||
pool = redis.ConnectionPool.from_url('unix:///socket', db=1)
|
||||
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'path': '/socket',
|
||||
'db': 1,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_db_in_querystring(self):
|
||||
pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1)
|
||||
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'path': '/socket',
|
||||
'db': 2,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
def test_extra_querystring_options(self):
|
||||
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
|
||||
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'path': '/socket',
|
||||
'db': 0,
|
||||
'password': None,
|
||||
'a': '1',
|
||||
'b': '2'
|
||||
}
|
||||
|
||||
|
||||
class TestSSLConnectionURLParsing(object):
|
||||
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
|
||||
def test_defaults(self):
|
||||
pool = redis.ConnectionPool.from_url('rediss://localhost')
|
||||
assert pool.connection_class == redis.SSLConnection
|
||||
assert pool.connection_kwargs == {
|
||||
'host': 'localhost',
|
||||
'port': 6379,
|
||||
'db': 0,
|
||||
'password': None,
|
||||
}
|
||||
|
||||
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
|
||||
def test_cert_reqs_options(self):
|
||||
import ssl
|
||||
pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
|
||||
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
|
||||
|
||||
pool = redis.ConnectionPool.from_url(
|
||||
'rediss://?ssl_cert_reqs=optional')
|
||||
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
|
||||
|
||||
pool = redis.ConnectionPool.from_url(
|
||||
'rediss://?ssl_cert_reqs=required')
|
||||
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
|
||||
|
||||
|
||||
class TestConnection(object):
|
||||
def test_on_connect_error(self):
|
||||
"""
|
||||
An error in Connection.on_connect should disconnect from the server
|
||||
see for details: https://github.com/andymccurdy/redis-py/issues/368
|
||||
"""
|
||||
# this assumes the Redis server being tested against doesn't have
|
||||
# 9999 databases ;)
|
||||
bad_connection = redis.Redis(db=9999)
|
||||
# an error should be raised on connect
|
||||
with pytest.raises(redis.RedisError):
|
||||
bad_connection.info()
|
||||
pool = bad_connection.connection_pool
|
||||
assert len(pool._available_connections) == 1
|
||||
assert not pool._available_connections[0]._sock
|
||||
|
||||
@skip_if_server_version_lt('2.8.8')
|
||||
def test_busy_loading_disconnects_socket(self, r):
|
||||
"""
|
||||
If Redis raises a LOADING error, the connection should be
|
||||
disconnected and a BusyLoadingError raised
|
||||
"""
|
||||
with pytest.raises(redis.BusyLoadingError):
|
||||
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
|
||||
pool = r.connection_pool
|
||||
assert len(pool._available_connections) == 1
|
||||
assert not pool._available_connections[0]._sock
|
||||
|
||||
@skip_if_server_version_lt('2.8.8')
|
||||
def test_busy_loading_from_pipeline_immediate_command(self, r):
|
||||
"""
|
||||
BusyLoadingErrors should raise from Pipelines that execute a
|
||||
command immediately, like WATCH does.
|
||||
"""
|
||||
pipe = r.pipeline()
|
||||
with pytest.raises(redis.BusyLoadingError):
|
||||
pipe.immediate_execute_command('DEBUG', 'ERROR',
|
||||
'LOADING fake message')
|
||||
pool = r.connection_pool
|
||||
assert not pipe.connection
|
||||
assert len(pool._available_connections) == 1
|
||||
assert not pool._available_connections[0]._sock
|
||||
|
||||
@skip_if_server_version_lt('2.8.8')
|
||||
def test_busy_loading_from_pipeline(self, r):
|
||||
"""
|
||||
BusyLoadingErrors should be raised from a pipeline execution
|
||||
regardless of the raise_on_error flag.
|
||||
"""
|
||||
pipe = r.pipeline()
|
||||
pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
|
||||
with pytest.raises(redis.BusyLoadingError):
|
||||
pipe.execute()
|
||||
pool = r.connection_pool
|
||||
assert not pipe.connection
|
||||
assert len(pool._available_connections) == 1
|
||||
assert not pool._available_connections[0]._sock
|
||||
|
||||
@skip_if_server_version_lt('2.8.8')
|
||||
def test_read_only_error(self, r):
|
||||
"READONLY errors get turned in ReadOnlyError exceptions"
|
||||
with pytest.raises(redis.ReadOnlyError):
|
||||
r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah')
|
||||
|
||||
def test_connect_from_url_tcp(self):
|
||||
connection = redis.Redis.from_url('redis://localhost')
|
||||
pool = connection.connection_pool
|
||||
|
||||
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
|
||||
'ConnectionPool',
|
||||
'Connection',
|
||||
'host=localhost,port=6379,db=0',
|
||||
)
|
||||
|
||||
def test_connect_from_url_unix(self):
|
||||
connection = redis.Redis.from_url('unix:///path/to/socket')
|
||||
pool = connection.connection_pool
|
||||
|
||||
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
|
||||
'ConnectionPool',
|
||||
'UnixDomainSocketConnection',
|
||||
'path=/path/to/socket,db=0',
|
||||
)
|
|
@ -0,0 +1,33 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
|
||||
from redis._compat import unichr, u, unicode
|
||||
from .conftest import r as _redis_client
|
||||
|
||||
|
||||
class TestEncoding(object):
|
||||
@pytest.fixture()
|
||||
def r(self, request):
|
||||
return _redis_client(request=request, decode_responses=True)
|
||||
|
||||
def test_simple_encoding(self, r):
|
||||
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
|
||||
r['unicode-string'] = unicode_string
|
||||
cached_val = r['unicode-string']
|
||||
assert isinstance(cached_val, unicode)
|
||||
assert unicode_string == cached_val
|
||||
|
||||
def test_list_encoding(self, r):
|
||||
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
|
||||
result = [unicode_string, unicode_string, unicode_string]
|
||||
r.rpush('a', *result)
|
||||
assert r.lrange('a', 0, -1) == result
|
||||
|
||||
|
||||
class TestCommandsAndTokensArentEncoded(object):
|
||||
@pytest.fixture()
|
||||
def r(self, request):
|
||||
return _redis_client(request=request, charset='utf-16')
|
||||
|
||||
def test_basic_command(self, r):
|
||||
r.set('hello', 'world')
|
|
@ -0,0 +1,167 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from redis.exceptions import LockError, ResponseError
|
||||
from redis.lock import Lock, LuaLock
|
||||
|
||||
|
||||
class TestLock(object):
|
||||
lock_class = Lock
|
||||
|
||||
def get_lock(self, redis, *args, **kwargs):
|
||||
kwargs['lock_class'] = self.lock_class
|
||||
return redis.lock(*args, **kwargs)
|
||||
|
||||
def test_lock(self, sr):
|
||||
lock = self.get_lock(sr, 'foo')
|
||||
assert lock.acquire(blocking=False)
|
||||
assert sr.get('foo') == lock.local.token
|
||||
assert sr.ttl('foo') == -1
|
||||
lock.release()
|
||||
assert sr.get('foo') is None
|
||||
|
||||
def test_competing_locks(self, sr):
|
||||
lock1 = self.get_lock(sr, 'foo')
|
||||
lock2 = self.get_lock(sr, 'foo')
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
assert not lock1.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_timeout(self, sr):
|
||||
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||
assert lock.acquire(blocking=False)
|
||||
assert 8 < sr.ttl('foo') <= 10
|
||||
lock.release()
|
||||
|
||||
def test_float_timeout(self, sr):
|
||||
lock = self.get_lock(sr, 'foo', timeout=9.5)
|
||||
assert lock.acquire(blocking=False)
|
||||
assert 8 < sr.pttl('foo') <= 9500
|
||||
lock.release()
|
||||
|
||||
def test_blocking_timeout(self, sr):
|
||||
lock1 = self.get_lock(sr, 'foo')
|
||||
assert lock1.acquire(blocking=False)
|
||||
lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
|
||||
start = time.time()
|
||||
assert not lock2.acquire()
|
||||
assert (time.time() - start) > 0.2
|
||||
lock1.release()
|
||||
|
||||
def test_context_manager(self, sr):
|
||||
# blocking_timeout prevents a deadlock if the lock can't be acquired
|
||||
# for some reason
|
||||
with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
|
||||
assert sr.get('foo') == lock.local.token
|
||||
assert sr.get('foo') is None
|
||||
|
||||
def test_high_sleep_raises_error(self, sr):
|
||||
"If sleep is higher than timeout, it should raise an error"
|
||||
with pytest.raises(LockError):
|
||||
self.get_lock(sr, 'foo', timeout=1, sleep=2)
|
||||
|
||||
def test_releasing_unlocked_lock_raises_error(self, sr):
|
||||
lock = self.get_lock(sr, 'foo')
|
||||
with pytest.raises(LockError):
|
||||
lock.release()
|
||||
|
||||
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
|
||||
lock = self.get_lock(sr, 'foo')
|
||||
lock.acquire(blocking=False)
|
||||
# manually change the token
|
||||
sr.set('foo', 'a')
|
||||
with pytest.raises(LockError):
|
||||
lock.release()
|
||||
# even though we errored, the token is still cleared
|
||||
assert lock.local.token is None
|
||||
|
||||
def test_extend_lock(self, sr):
|
||||
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||
assert lock.acquire(blocking=False)
|
||||
assert 8000 < sr.pttl('foo') <= 10000
|
||||
assert lock.extend(10)
|
||||
assert 16000 < sr.pttl('foo') <= 20000
|
||||
lock.release()
|
||||
|
||||
def test_extend_lock_float(self, sr):
|
||||
lock = self.get_lock(sr, 'foo', timeout=10.0)
|
||||
assert lock.acquire(blocking=False)
|
||||
assert 8000 < sr.pttl('foo') <= 10000
|
||||
assert lock.extend(10.0)
|
||||
assert 16000 < sr.pttl('foo') <= 20000
|
||||
lock.release()
|
||||
|
||||
def test_extending_unlocked_lock_raises_error(self, sr):
|
||||
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||
with pytest.raises(LockError):
|
||||
lock.extend(10)
|
||||
|
||||
def test_extending_lock_with_no_timeout_raises_error(self, sr):
|
||||
lock = self.get_lock(sr, 'foo')
|
||||
assert lock.acquire(blocking=False)
|
||||
with pytest.raises(LockError):
|
||||
lock.extend(10)
|
||||
lock.release()
|
||||
|
||||
def test_extending_lock_no_longer_owned_raises_error(self, sr):
|
||||
lock = self.get_lock(sr, 'foo')
|
||||
assert lock.acquire(blocking=False)
|
||||
sr.set('foo', 'a')
|
||||
with pytest.raises(LockError):
|
||||
lock.extend(10)
|
||||
|
||||
|
||||
class TestLuaLock(TestLock):
|
||||
lock_class = LuaLock
|
||||
|
||||
|
||||
class TestLockClassSelection(object):
|
||||
def test_lock_class_argument(self, sr):
|
||||
lock = sr.lock('foo', lock_class=Lock)
|
||||
assert type(lock) == Lock
|
||||
lock = sr.lock('foo', lock_class=LuaLock)
|
||||
assert type(lock) == LuaLock
|
||||
|
||||
def test_cached_lualock_flag(self, sr):
|
||||
try:
|
||||
sr._use_lua_lock = True
|
||||
lock = sr.lock('foo')
|
||||
assert type(lock) == LuaLock
|
||||
finally:
|
||||
sr._use_lua_lock = None
|
||||
|
||||
def test_cached_lock_flag(self, sr):
|
||||
try:
|
||||
sr._use_lua_lock = False
|
||||
lock = sr.lock('foo')
|
||||
assert type(lock) == Lock
|
||||
finally:
|
||||
sr._use_lua_lock = None
|
||||
|
||||
def test_lua_compatible_server(self, sr, monkeypatch):
|
||||
@classmethod
|
||||
def mock_register(cls, redis):
|
||||
return
|
||||
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
|
||||
try:
|
||||
lock = sr.lock('foo')
|
||||
assert type(lock) == LuaLock
|
||||
assert sr._use_lua_lock is True
|
||||
finally:
|
||||
sr._use_lua_lock = None
|
||||
|
||||
def test_lua_unavailable(self, sr, monkeypatch):
|
||||
@classmethod
|
||||
def mock_register(cls, redis):
|
||||
raise ResponseError()
|
||||
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
|
||||
try:
|
||||
lock = sr.lock('foo')
|
||||
assert type(lock) == Lock
|
||||
assert sr._use_lua_lock is False
|
||||
finally:
|
||||
sr._use_lua_lock = None
|
|
@ -0,0 +1,226 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
|
||||
import redis
|
||||
from redis._compat import b, u, unichr, unicode
|
||||
|
||||
|
||||
class TestPipeline(object):
|
||||
def test_pipeline(self, r):
|
||||
with r.pipeline() as pipe:
|
||||
pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4)
|
||||
pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True)
|
||||
assert pipe.execute() == \
|
||||
[
|
||||
True,
|
||||
b('a1'),
|
||||
True,
|
||||
True,
|
||||
2.0,
|
||||
[(b('z1'), 2.0), (b('z2'), 4)],
|
||||
]
|
||||
|
||||
def test_pipeline_length(self, r):
|
||||
with r.pipeline() as pipe:
|
||||
# Initially empty.
|
||||
assert len(pipe) == 0
|
||||
assert not pipe
|
||||
|
||||
# Fill 'er up!
|
||||
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
|
||||
assert len(pipe) == 3
|
||||
assert pipe
|
||||
|
||||
# Execute calls reset(), so empty once again.
|
||||
pipe.execute()
|
||||
assert len(pipe) == 0
|
||||
assert not pipe
|
||||
|
||||
def test_pipeline_no_transaction(self, r):
|
||||
with r.pipeline(transaction=False) as pipe:
|
||||
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
|
||||
assert pipe.execute() == [True, True, True]
|
||||
assert r['a'] == b('a1')
|
||||
assert r['b'] == b('b1')
|
||||
assert r['c'] == b('c1')
|
||||
|
||||
def test_pipeline_no_transaction_watch(self, r):
|
||||
r['a'] = 0
|
||||
|
||||
with r.pipeline(transaction=False) as pipe:
|
||||
pipe.watch('a')
|
||||
a = pipe.get('a')
|
||||
|
||||
pipe.multi()
|
||||
pipe.set('a', int(a) + 1)
|
||||
assert pipe.execute() == [True]
|
||||
|
||||
def test_pipeline_no_transaction_watch_failure(self, r):
|
||||
r['a'] = 0
|
||||
|
||||
with r.pipeline(transaction=False) as pipe:
|
||||
pipe.watch('a')
|
||||
a = pipe.get('a')
|
||||
|
||||
r['a'] = 'bad'
|
||||
|
||||
pipe.multi()
|
||||
pipe.set('a', int(a) + 1)
|
||||
|
||||
with pytest.raises(redis.WatchError):
|
||||
pipe.execute()
|
||||
|
||||
assert r['a'] == b('bad')
|
||||
|
||||
def test_exec_error_in_response(self, r):
|
||||
"""
|
||||
an invalid pipeline command at exec time adds the exception instance
|
||||
to the list of returned values
|
||||
"""
|
||||
r['c'] = 'a'
|
||||
with r.pipeline() as pipe:
|
||||
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
|
||||
result = pipe.execute(raise_on_error=False)
|
||||
|
||||
assert result[0]
|
||||
assert r['a'] == b('1')
|
||||
assert result[1]
|
||||
assert r['b'] == b('2')
|
||||
|
||||
# we can't lpush to a key that's a string value, so this should
|
||||
# be a ResponseError exception
|
||||
assert isinstance(result[2], redis.ResponseError)
|
||||
assert r['c'] == b('a')
|
||||
|
||||
# since this isn't a transaction, the other commands after the
|
||||
# error are still executed
|
||||
assert result[3]
|
||||
assert r['d'] == b('4')
|
||||
|
||||
# make sure the pipe was restored to a working state
|
||||
assert pipe.set('z', 'zzz').execute() == [True]
|
||||
assert r['z'] == b('zzz')
|
||||
|
||||
def test_exec_error_raised(self, r):
|
||||
r['c'] = 'a'
|
||||
with r.pipeline() as pipe:
|
||||
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
|
||||
with pytest.raises(redis.ResponseError) as ex:
|
||||
pipe.execute()
|
||||
assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of '
|
||||
'pipeline caused error: ')
|
||||
|
||||
# make sure the pipe was restored to a working state
|
||||
assert pipe.set('z', 'zzz').execute() == [True]
|
||||
assert r['z'] == b('zzz')
|
||||
|
||||
def test_parse_error_raised(self, r):
|
||||
with r.pipeline() as pipe:
|
||||
# the zrem is invalid because we don't pass any keys to it
|
||||
pipe.set('a', 1).zrem('b').set('b', 2)
|
||||
with pytest.raises(redis.ResponseError) as ex:
|
||||
pipe.execute()
|
||||
|
||||
assert unicode(ex.value).startswith('Command # 2 (ZREM b) of '
|
||||
'pipeline caused error: ')
|
||||
|
||||
# make sure the pipe was restored to a working state
|
||||
assert pipe.set('z', 'zzz').execute() == [True]
|
||||
assert r['z'] == b('zzz')
|
||||
|
||||
def test_watch_succeed(self, r):
|
||||
r['a'] = 1
|
||||
r['b'] = 2
|
||||
|
||||
with r.pipeline() as pipe:
|
||||
pipe.watch('a', 'b')
|
||||
assert pipe.watching
|
||||
a_value = pipe.get('a')
|
||||
b_value = pipe.get('b')
|
||||
assert a_value == b('1')
|
||||
assert b_value == b('2')
|
||||
pipe.multi()
|
||||
|
||||
pipe.set('c', 3)
|
||||
assert pipe.execute() == [True]
|
||||
assert not pipe.watching
|
||||
|
||||
def test_watch_failure(self, r):
|
||||
r['a'] = 1
|
||||
r['b'] = 2
|
||||
|
||||
with r.pipeline() as pipe:
|
||||
pipe.watch('a', 'b')
|
||||
r['b'] = 3
|
||||
pipe.multi()
|
||||
pipe.get('a')
|
||||
with pytest.raises(redis.WatchError):
|
||||
pipe.execute()
|
||||
|
||||
assert not pipe.watching
|
||||
|
||||
def test_unwatch(self, r):
|
||||
r['a'] = 1
|
||||
r['b'] = 2
|
||||
|
||||
with r.pipeline() as pipe:
|
||||
pipe.watch('a', 'b')
|
||||
r['b'] = 3
|
||||
pipe.unwatch()
|
||||
assert not pipe.watching
|
||||
pipe.get('a')
|
||||
assert pipe.execute() == [b('1')]
|
||||
|
||||
def test_transaction_callable(self, r):
|
||||
r['a'] = 1
|
||||
r['b'] = 2
|
||||
has_run = []
|
||||
|
||||
def my_transaction(pipe):
|
||||
a_value = pipe.get('a')
|
||||
assert a_value in (b('1'), b('2'))
|
||||
b_value = pipe.get('b')
|
||||
assert b_value == b('2')
|
||||
|
||||
# silly run-once code... incr's "a" so WatchError should be raised
|
||||
# forcing this all to run again. this should incr "a" once to "2"
|
||||
if not has_run:
|
||||
r.incr('a')
|
||||
has_run.append('it has')
|
||||
|
||||
pipe.multi()
|
||||
pipe.set('c', int(a_value) + int(b_value))
|
||||
|
||||
result = r.transaction(my_transaction, 'a', 'b')
|
||||
assert result == [True]
|
||||
assert r['c'] == b('4')
|
||||
|
||||
def test_exec_error_in_no_transaction_pipeline(self, r):
|
||||
r['a'] = 1
|
||||
with r.pipeline(transaction=False) as pipe:
|
||||
pipe.llen('a')
|
||||
pipe.expire('a', 100)
|
||||
|
||||
with pytest.raises(redis.ResponseError) as ex:
|
||||
pipe.execute()
|
||||
|
||||
assert unicode(ex.value).startswith('Command # 1 (LLEN a) of '
|
||||
'pipeline caused error: ')
|
||||
|
||||
assert r['a'] == b('1')
|
||||
|
||||
def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
|
||||
key = unichr(3456) + u('abcd') + unichr(3421)
|
||||
r[key] = 1
|
||||
with r.pipeline(transaction=False) as pipe:
|
||||
pipe.llen(key)
|
||||
pipe.expire(key, 100)
|
||||
|
||||
with pytest.raises(redis.ResponseError) as ex:
|
||||
pipe.execute()
|
||||
|
||||
expected = unicode('Command # 1 (LLEN %s) of pipeline caused '
|
||||
'error: ') % key
|
||||
assert unicode(ex.value).startswith(expected)
|
||||
|
||||
assert r[key] == b('1')
|
|
@ -0,0 +1,392 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
import time
|
||||
|
||||
import redis
|
||||
from redis.exceptions import ConnectionError
|
||||
from redis._compat import basestring, u, unichr
|
||||
|
||||
from .conftest import r as _redis_client
|
||||
|
||||
|
||||
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
|
||||
now = time.time()
|
||||
timeout = now + timeout
|
||||
while now < timeout:
|
||||
message = pubsub.get_message(
|
||||
ignore_subscribe_messages=ignore_subscribe_messages)
|
||||
if message is not None:
|
||||
return message
|
||||
time.sleep(0.01)
|
||||
now = time.time()
|
||||
return None
|
||||
|
||||
|
||||
def make_message(type, channel, data, pattern=None):
|
||||
return {
|
||||
'type': type,
|
||||
'pattern': pattern and pattern.encode('utf-8') or None,
|
||||
'channel': channel.encode('utf-8'),
|
||||
'data': data.encode('utf-8') if isinstance(data, basestring) else data
|
||||
}
|
||||
|
||||
|
||||
def make_subscribe_test_data(pubsub, type):
|
||||
if type == 'channel':
|
||||
return {
|
||||
'p': pubsub,
|
||||
'sub_type': 'subscribe',
|
||||
'unsub_type': 'unsubscribe',
|
||||
'sub_func': pubsub.subscribe,
|
||||
'unsub_func': pubsub.unsubscribe,
|
||||
'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
|
||||
}
|
||||
elif type == 'pattern':
|
||||
return {
|
||||
'p': pubsub,
|
||||
'sub_type': 'psubscribe',
|
||||
'unsub_type': 'punsubscribe',
|
||||
'sub_func': pubsub.psubscribe,
|
||||
'unsub_func': pubsub.punsubscribe,
|
||||
'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
|
||||
}
|
||||
assert False, 'invalid subscribe type: %s' % type
|
||||
|
||||
|
||||
class TestPubSubSubscribeUnsubscribe(object):
|
||||
|
||||
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
|
||||
unsub_func, keys):
|
||||
for key in keys:
|
||||
assert sub_func(key) is None
|
||||
|
||||
# should be a message for each channel/pattern we just subscribed to
|
||||
for i, key in enumerate(keys):
|
||||
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
|
||||
|
||||
for key in keys:
|
||||
assert unsub_func(key) is None
|
||||
|
||||
# should be a message for each channel/pattern we just unsubscribed
|
||||
# from
|
||||
for i, key in enumerate(keys):
|
||||
i = len(keys) - 1 - i
|
||||
assert wait_for_message(p) == make_message(unsub_type, key, i)
|
||||
|
||||
def test_channel_subscribe_unsubscribe(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||
self._test_subscribe_unsubscribe(**kwargs)
|
||||
|
||||
def test_pattern_subscribe_unsubscribe(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||
self._test_subscribe_unsubscribe(**kwargs)
|
||||
|
||||
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
|
||||
sub_func, unsub_func, keys):
|
||||
|
||||
for key in keys:
|
||||
assert sub_func(key) is None
|
||||
|
||||
# should be a message for each channel/pattern we just subscribed to
|
||||
for i, key in enumerate(keys):
|
||||
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
|
||||
|
||||
# manually disconnect
|
||||
p.connection.disconnect()
|
||||
|
||||
# calling get_message again reconnects and resubscribes
|
||||
# note, we may not re-subscribe to channels in exactly the same order
|
||||
# so we have to do some extra checks to make sure we got them all
|
||||
messages = []
|
||||
for i in range(len(keys)):
|
||||
messages.append(wait_for_message(p))
|
||||
|
||||
unique_channels = set()
|
||||
assert len(messages) == len(keys)
|
||||
for i, message in enumerate(messages):
|
||||
assert message['type'] == sub_type
|
||||
assert message['data'] == i + 1
|
||||
assert isinstance(message['channel'], bytes)
|
||||
channel = message['channel'].decode('utf-8')
|
||||
unique_channels.add(channel)
|
||||
|
||||
assert len(unique_channels) == len(keys)
|
||||
for channel in unique_channels:
|
||||
assert channel in keys
|
||||
|
||||
def test_resubscribe_to_channels_on_reconnection(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||
self._test_resubscribe_on_reconnection(**kwargs)
|
||||
|
||||
def test_resubscribe_to_patterns_on_reconnection(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||
self._test_resubscribe_on_reconnection(**kwargs)
|
||||
|
||||
def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func,
|
||||
unsub_func, keys):
|
||||
|
||||
assert p.subscribed is False
|
||||
sub_func(keys[0])
|
||||
# we're now subscribed even though we haven't processed the
|
||||
# reply from the server just yet
|
||||
assert p.subscribed is True
|
||||
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
|
||||
# we're still subscribed
|
||||
assert p.subscribed is True
|
||||
|
||||
# unsubscribe from all channels
|
||||
unsub_func()
|
||||
# we're still technically subscribed until we process the
|
||||
# response messages from the server
|
||||
assert p.subscribed is True
|
||||
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
|
||||
# now we're no longer subscribed as no more messages can be delivered
|
||||
# to any channels we were listening to
|
||||
assert p.subscribed is False
|
||||
|
||||
# subscribing again flips the flag back
|
||||
sub_func(keys[0])
|
||||
assert p.subscribed is True
|
||||
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
|
||||
|
||||
# unsubscribe again
|
||||
unsub_func()
|
||||
assert p.subscribed is True
|
||||
# subscribe to another channel before reading the unsubscribe response
|
||||
sub_func(keys[1])
|
||||
assert p.subscribed is True
|
||||
# read the unsubscribe for key1
|
||||
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
|
||||
# we're still subscribed to key2, so subscribed should still be True
|
||||
assert p.subscribed is True
|
||||
# read the key2 subscribe message
|
||||
assert wait_for_message(p) == make_message(sub_type, keys[1], 1)
|
||||
unsub_func()
|
||||
# haven't read the message yet, so we're still subscribed
|
||||
assert p.subscribed is True
|
||||
assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
|
||||
# now we're finally unsubscribed
|
||||
assert p.subscribed is False
|
||||
|
||||
def test_subscribe_property_with_channels(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||
self._test_subscribed_property(**kwargs)
|
||||
|
||||
def test_subscribe_property_with_patterns(self, r):
|
||||
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||
self._test_subscribed_property(**kwargs)
|
||||
|
||||
def test_ignore_all_subscribe_messages(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
|
||||
checks = (
|
||||
(p.subscribe, 'foo'),
|
||||
(p.unsubscribe, 'foo'),
|
||||
(p.psubscribe, 'f*'),
|
||||
(p.punsubscribe, 'f*'),
|
||||
)
|
||||
|
||||
assert p.subscribed is False
|
||||
for func, channel in checks:
|
||||
assert func(channel) is None
|
||||
assert p.subscribed is True
|
||||
assert wait_for_message(p) is None
|
||||
assert p.subscribed is False
|
||||
|
||||
def test_ignore_individual_subscribe_messages(self, r):
|
||||
p = r.pubsub()
|
||||
|
||||
checks = (
|
||||
(p.subscribe, 'foo'),
|
||||
(p.unsubscribe, 'foo'),
|
||||
(p.psubscribe, 'f*'),
|
||||
(p.punsubscribe, 'f*'),
|
||||
)
|
||||
|
||||
assert p.subscribed is False
|
||||
for func, channel in checks:
|
||||
assert func(channel) is None
|
||||
assert p.subscribed is True
|
||||
message = wait_for_message(p, ignore_subscribe_messages=True)
|
||||
assert message is None
|
||||
assert p.subscribed is False
|
||||
|
||||
|
||||
class TestPubSubMessages(object):
|
||||
def setup_method(self, method):
|
||||
self.message = None
|
||||
|
||||
def message_handler(self, message):
|
||||
self.message = message
|
||||
|
||||
def test_published_message_to_channel(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.subscribe('foo')
|
||||
assert r.publish('foo', 'test message') == 1
|
||||
|
||||
message = wait_for_message(p)
|
||||
assert isinstance(message, dict)
|
||||
assert message == make_message('message', 'foo', 'test message')
|
||||
|
||||
def test_published_message_to_pattern(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.subscribe('foo')
|
||||
p.psubscribe('f*')
|
||||
# 1 to pattern, 1 to channel
|
||||
assert r.publish('foo', 'test message') == 2
|
||||
|
||||
message1 = wait_for_message(p)
|
||||
message2 = wait_for_message(p)
|
||||
assert isinstance(message1, dict)
|
||||
assert isinstance(message2, dict)
|
||||
|
||||
expected = [
|
||||
make_message('message', 'foo', 'test message'),
|
||||
make_message('pmessage', 'foo', 'test message', pattern='f*')
|
||||
]
|
||||
|
||||
assert message1 in expected
|
||||
assert message2 in expected
|
||||
assert message1 != message2
|
||||
|
||||
def test_channel_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.subscribe(foo=self.message_handler)
|
||||
assert r.publish('foo', 'test message') == 1
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == make_message('message', 'foo', 'test message')
|
||||
|
||||
def test_pattern_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.psubscribe(**{'f*': self.message_handler})
|
||||
assert r.publish('foo', 'test message') == 1
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == make_message('pmessage', 'foo', 'test message',
|
||||
pattern='f*')
|
||||
|
||||
def test_unicode_channel_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
channel = u('uni') + unichr(4456) + u('code')
|
||||
channels = {channel: self.message_handler}
|
||||
p.subscribe(**channels)
|
||||
assert r.publish(channel, 'test message') == 1
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == make_message('message', channel, 'test message')
|
||||
|
||||
def test_unicode_pattern_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
pattern = u('uni') + unichr(4456) + u('*')
|
||||
channel = u('uni') + unichr(4456) + u('code')
|
||||
p.psubscribe(**{pattern: self.message_handler})
|
||||
assert r.publish(channel, 'test message') == 1
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == make_message('pmessage', channel,
|
||||
'test message', pattern=pattern)
|
||||
|
||||
|
||||
class TestPubSubAutoDecoding(object):
|
||||
"These tests only validate that we get unicode values back"
|
||||
|
||||
channel = u('uni') + unichr(4456) + u('code')
|
||||
pattern = u('uni') + unichr(4456) + u('*')
|
||||
data = u('abc') + unichr(4458) + u('123')
|
||||
|
||||
def make_message(self, type, channel, data, pattern=None):
|
||||
return {
|
||||
'type': type,
|
||||
'channel': channel,
|
||||
'pattern': pattern,
|
||||
'data': data
|
||||
}
|
||||
|
||||
def setup_method(self, method):
|
||||
self.message = None
|
||||
|
||||
def message_handler(self, message):
|
||||
self.message = message
|
||||
|
||||
@pytest.fixture()
|
||||
def r(self, request):
|
||||
return _redis_client(request=request, decode_responses=True)
|
||||
|
||||
def test_channel_subscribe_unsubscribe(self, r):
|
||||
p = r.pubsub()
|
||||
p.subscribe(self.channel)
|
||||
assert wait_for_message(p) == self.make_message('subscribe',
|
||||
self.channel, 1)
|
||||
|
||||
p.unsubscribe(self.channel)
|
||||
assert wait_for_message(p) == self.make_message('unsubscribe',
|
||||
self.channel, 0)
|
||||
|
||||
def test_pattern_subscribe_unsubscribe(self, r):
|
||||
p = r.pubsub()
|
||||
p.psubscribe(self.pattern)
|
||||
assert wait_for_message(p) == self.make_message('psubscribe',
|
||||
self.pattern, 1)
|
||||
|
||||
p.punsubscribe(self.pattern)
|
||||
assert wait_for_message(p) == self.make_message('punsubscribe',
|
||||
self.pattern, 0)
|
||||
|
||||
def test_channel_publish(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.subscribe(self.channel)
|
||||
r.publish(self.channel, self.data)
|
||||
assert wait_for_message(p) == self.make_message('message',
|
||||
self.channel,
|
||||
self.data)
|
||||
|
||||
def test_pattern_publish(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.psubscribe(self.pattern)
|
||||
r.publish(self.channel, self.data)
|
||||
assert wait_for_message(p) == self.make_message('pmessage',
|
||||
self.channel,
|
||||
self.data,
|
||||
pattern=self.pattern)
|
||||
|
||||
def test_channel_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.subscribe(**{self.channel: self.message_handler})
|
||||
r.publish(self.channel, self.data)
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == self.make_message('message', self.channel,
|
||||
self.data)
|
||||
|
||||
# test that we reconnected to the correct channel
|
||||
p.connection.disconnect()
|
||||
assert wait_for_message(p) is None # should reconnect
|
||||
new_data = self.data + u('new data')
|
||||
r.publish(self.channel, new_data)
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == self.make_message('message', self.channel,
|
||||
new_data)
|
||||
|
||||
def test_pattern_message_handler(self, r):
|
||||
p = r.pubsub(ignore_subscribe_messages=True)
|
||||
p.psubscribe(**{self.pattern: self.message_handler})
|
||||
r.publish(self.channel, self.data)
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == self.make_message('pmessage', self.channel,
|
||||
self.data,
|
||||
pattern=self.pattern)
|
||||
|
||||
# test that we reconnected to the correct pattern
|
||||
p.connection.disconnect()
|
||||
assert wait_for_message(p) is None # should reconnect
|
||||
new_data = self.data + u('new data')
|
||||
r.publish(self.channel, new_data)
|
||||
assert wait_for_message(p) is None
|
||||
assert self.message == self.make_message('pmessage', self.channel,
|
||||
new_data,
|
||||
pattern=self.pattern)
|
||||
|
||||
|
||||
class TestPubSubRedisDown(object):
|
||||
|
||||
def test_channel_subscribe(self, r):
|
||||
r = redis.Redis(host='localhost', port=6390)
|
||||
p = r.pubsub()
|
||||
with pytest.raises(ConnectionError):
|
||||
p.subscribe('foo')
|
|
@ -0,0 +1,82 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
|
||||
from redis import exceptions
|
||||
from redis._compat import b
|
||||
|
||||
|
||||
multiply_script = """
|
||||
local value = redis.call('GET', KEYS[1])
|
||||
value = tonumber(value)
|
||||
return value * ARGV[1]"""
|
||||
|
||||
|
||||
class TestScripting(object):
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_scripts(self, r):
|
||||
r.script_flush()
|
||||
|
||||
def test_eval(self, r):
|
||||
r.set('a', 2)
|
||||
# 2 * 3 == 6
|
||||
assert r.eval(multiply_script, 1, 'a', 3) == 6
|
||||
|
||||
def test_evalsha(self, r):
|
||||
r.set('a', 2)
|
||||
sha = r.script_load(multiply_script)
|
||||
# 2 * 3 == 6
|
||||
assert r.evalsha(sha, 1, 'a', 3) == 6
|
||||
|
||||
def test_evalsha_script_not_loaded(self, r):
|
||||
r.set('a', 2)
|
||||
sha = r.script_load(multiply_script)
|
||||
# remove the script from Redis's cache
|
||||
r.script_flush()
|
||||
with pytest.raises(exceptions.NoScriptError):
|
||||
r.evalsha(sha, 1, 'a', 3)
|
||||
|
||||
def test_script_loading(self, r):
|
||||
# get the sha, then clear the cache
|
||||
sha = r.script_load(multiply_script)
|
||||
r.script_flush()
|
||||
assert r.script_exists(sha) == [False]
|
||||
r.script_load(multiply_script)
|
||||
assert r.script_exists(sha) == [True]
|
||||
|
||||
def test_script_object(self, r):
|
||||
r.set('a', 2)
|
||||
multiply = r.register_script(multiply_script)
|
||||
assert not multiply.sha
|
||||
# test evalsha fail -> script load + retry
|
||||
assert multiply(keys=['a'], args=[3]) == 6
|
||||
assert multiply.sha
|
||||
assert r.script_exists(multiply.sha) == [True]
|
||||
# test first evalsha
|
||||
assert multiply(keys=['a'], args=[3]) == 6
|
||||
|
||||
def test_script_object_in_pipeline(self, r):
|
||||
multiply = r.register_script(multiply_script)
|
||||
assert not multiply.sha
|
||||
pipe = r.pipeline()
|
||||
pipe.set('a', 2)
|
||||
pipe.get('a')
|
||||
multiply(keys=['a'], args=[3], client=pipe)
|
||||
# even though the pipeline wasn't executed yet, we made sure the
|
||||
# script was loaded and got a valid sha
|
||||
assert multiply.sha
|
||||
assert r.script_exists(multiply.sha) == [True]
|
||||
# [SET worked, GET 'a', result of multiple script]
|
||||
assert pipe.execute() == [True, b('2'), 6]
|
||||
|
||||
# purge the script from redis's cache and re-run the pipeline
|
||||
# the multiply script object knows it's sha, so it shouldn't get
|
||||
# reloaded until pipe.execute()
|
||||
r.script_flush()
|
||||
pipe = r.pipeline()
|
||||
pipe.set('a', 2)
|
||||
pipe.get('a')
|
||||
assert multiply.sha
|
||||
multiply(keys=['a'], args=[3], client=pipe)
|
||||
assert r.script_exists(multiply.sha) == [False]
|
||||
# [SET worked, GET 'a', result of multiple script]
|
||||
assert pipe.execute() == [True, b('2'), 6]
|
|
@ -0,0 +1,173 @@
|
|||
from __future__ import with_statement
|
||||
import pytest
|
||||
|
||||
from redis import exceptions
|
||||
from redis.sentinel import (Sentinel, SentinelConnectionPool,
|
||||
MasterNotFoundError, SlaveNotFoundError)
|
||||
from redis._compat import next
|
||||
import redis.sentinel
|
||||
|
||||
|
||||
class SentinelTestClient(object):
|
||||
def __init__(self, cluster, id):
|
||||
self.cluster = cluster
|
||||
self.id = id
|
||||
|
||||
def sentinel_masters(self):
|
||||
self.cluster.connection_error_if_down(self)
|
||||
return {self.cluster.service_name: self.cluster.master}
|
||||
|
||||
def sentinel_slaves(self, master_name):
|
||||
self.cluster.connection_error_if_down(self)
|
||||
if master_name != self.cluster.service_name:
|
||||
return []
|
||||
return self.cluster.slaves
|
||||
|
||||
|
||||
class SentinelTestCluster(object):
|
||||
def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
|
||||
self.clients = {}
|
||||
self.master = {
|
||||
'ip': ip,
|
||||
'port': port,
|
||||
'is_master': True,
|
||||
'is_sdown': False,
|
||||
'is_odown': False,
|
||||
'num-other-sentinels': 0,
|
||||
}
|
||||
self.service_name = service_name
|
||||
self.slaves = []
|
||||
self.nodes_down = set()
|
||||
|
||||
def connection_error_if_down(self, node):
|
||||
if node.id in self.nodes_down:
|
||||
raise exceptions.ConnectionError
|
||||
|
||||
def client(self, host, port, **kwargs):
|
||||
return SentinelTestClient(self, (host, port))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cluster(request):
|
||||
def teardown():
|
||||
redis.sentinel.StrictRedis = saved_StrictRedis
|
||||
cluster = SentinelTestCluster()
|
||||
saved_StrictRedis = redis.sentinel.StrictRedis
|
||||
redis.sentinel.StrictRedis = cluster.client
|
||||
request.addfinalizer(teardown)
|
||||
return cluster
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sentinel(request, cluster):
|
||||
return Sentinel([('foo', 26379), ('bar', 26379)])
|
||||
|
||||
|
||||
def test_discover_master(sentinel):
|
||||
address = sentinel.discover_master('mymaster')
|
||||
assert address == ('127.0.0.1', 6379)
|
||||
|
||||
|
||||
def test_discover_master_error(sentinel):
|
||||
with pytest.raises(MasterNotFoundError):
|
||||
sentinel.discover_master('xxx')
|
||||
|
||||
|
||||
def test_discover_master_sentinel_down(cluster, sentinel):
|
||||
# Put first sentinel 'foo' down
|
||||
cluster.nodes_down.add(('foo', 26379))
|
||||
address = sentinel.discover_master('mymaster')
|
||||
assert address == ('127.0.0.1', 6379)
|
||||
# 'bar' is now first sentinel
|
||||
assert sentinel.sentinels[0].id == ('bar', 26379)
|
||||
|
||||
|
||||
def test_master_min_other_sentinels(cluster):
|
||||
sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
|
||||
# min_other_sentinels
|
||||
with pytest.raises(MasterNotFoundError):
|
||||
sentinel.discover_master('mymaster')
|
||||
cluster.master['num-other-sentinels'] = 2
|
||||
address = sentinel.discover_master('mymaster')
|
||||
assert address == ('127.0.0.1', 6379)
|
||||
|
||||
|
||||
def test_master_odown(cluster, sentinel):
|
||||
cluster.master['is_odown'] = True
|
||||
with pytest.raises(MasterNotFoundError):
|
||||
sentinel.discover_master('mymaster')
|
||||
|
||||
|
||||
def test_master_sdown(cluster, sentinel):
|
||||
cluster.master['is_sdown'] = True
|
||||
with pytest.raises(MasterNotFoundError):
|
||||
sentinel.discover_master('mymaster')
|
||||
|
||||
|
||||
def test_discover_slaves(cluster, sentinel):
|
||||
assert sentinel.discover_slaves('mymaster') == []
|
||||
|
||||
cluster.slaves = [
|
||||
{'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
|
||||
{'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
|
||||
]
|
||||
assert sentinel.discover_slaves('mymaster') == [
|
||||
('slave0', 1234), ('slave1', 1234)]
|
||||
|
||||
# slave0 -> ODOWN
|
||||
cluster.slaves[0]['is_odown'] = True
|
||||
assert sentinel.discover_slaves('mymaster') == [
|
||||
('slave1', 1234)]
|
||||
|
||||
# slave1 -> SDOWN
|
||||
cluster.slaves[1]['is_sdown'] = True
|
||||
assert sentinel.discover_slaves('mymaster') == []
|
||||
|
||||
cluster.slaves[0]['is_odown'] = False
|
||||
cluster.slaves[1]['is_sdown'] = False
|
||||
|
||||
# node0 -> DOWN
|
||||
cluster.nodes_down.add(('foo', 26379))
|
||||
assert sentinel.discover_slaves('mymaster') == [
|
||||
('slave0', 1234), ('slave1', 1234)]
|
||||
|
||||
|
||||
def test_master_for(cluster, sentinel):
|
||||
master = sentinel.master_for('mymaster', db=9)
|
||||
assert master.ping()
|
||||
assert master.connection_pool.master_address == ('127.0.0.1', 6379)
|
||||
|
||||
# Use internal connection check
|
||||
master = sentinel.master_for('mymaster', db=9, check_connection=True)
|
||||
assert master.ping()
|
||||
|
||||
|
||||
def test_slave_for(cluster, sentinel):
|
||||
cluster.slaves = [
|
||||
{'ip': '127.0.0.1', 'port': 6379,
|
||||
'is_odown': False, 'is_sdown': False},
|
||||
]
|
||||
slave = sentinel.slave_for('mymaster', db=9)
|
||||
assert slave.ping()
|
||||
|
||||
|
||||
def test_slave_for_slave_not_found_error(cluster, sentinel):
|
||||
cluster.master['is_odown'] = True
|
||||
slave = sentinel.slave_for('mymaster', db=9)
|
||||
with pytest.raises(SlaveNotFoundError):
|
||||
slave.ping()
|
||||
|
||||
|
||||
def test_slave_round_robin(cluster, sentinel):
|
||||
cluster.slaves = [
|
||||
{'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
|
||||
{'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
|
||||
]
|
||||
pool = SentinelConnectionPool('mymaster', sentinel)
|
||||
rotator = pool.rotate_slaves()
|
||||
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
|
||||
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
|
||||
# Fallback to master
|
||||
assert next(rotator) == ('127.0.0.1', 6379)
|
||||
with pytest.raises(SlaveNotFoundError):
|
||||
next(rotator)
|
|
@ -0,0 +1,362 @@
|
|||
--- refer from openresty redis lib
|
||||
|
||||
local sub = string.sub
|
||||
local byte = string.byte
|
||||
local tcp = ngx.socket.tcp
|
||||
local concat = table.concat
|
||||
local null = ngx.null
|
||||
local pairs = pairs
|
||||
local unpack = unpack
|
||||
local setmetatable = setmetatable
|
||||
local tonumber = tonumber
|
||||
local error = error
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local _M = new_tab(0, 155)
|
||||
_M._VERSION = '0.01'
|
||||
|
||||
|
||||
local commands = {
|
||||
--[[kv]]
|
||||
"decr",
|
||||
"decrby",
|
||||
"del",
|
||||
"exists",
|
||||
"get",
|
||||
"getset",
|
||||
"incr",
|
||||
"incrby",
|
||||
"mget",
|
||||
"mset",
|
||||
"set",
|
||||
"setnx",
|
||||
|
||||
--[[hash]]
|
||||
"hdel",
|
||||
"hexists",
|
||||
"hget",
|
||||
"hgetall",
|
||||
"hincrby",
|
||||
"hkeys",
|
||||
"hlen",
|
||||
"hmget",
|
||||
--[["hmset",]]
|
||||
"hset",
|
||||
"hvals",
|
||||
"hclear",
|
||||
|
||||
--[[list]]
|
||||
"lindex",
|
||||
"llen",
|
||||
"lpop",
|
||||
"lrange",
|
||||
"lpush",
|
||||
"rpop",
|
||||
"rpush",
|
||||
"lclear",
|
||||
|
||||
--[[zset]]
|
||||
"zadd",
|
||||
"zcard",
|
||||
"zcount",
|
||||
"zincrby",
|
||||
"zrange",
|
||||
"zrangebyscore",
|
||||
"zrank",
|
||||
"zrem",
|
||||
"zremrangebyrank",
|
||||
"zremrangebyscore",
|
||||
"zrevrange",
|
||||
"zrevrank",
|
||||
"zrevrangebyscore",
|
||||
"zscore",
|
||||
"zclear",
|
||||
|
||||
--[[server]]
|
||||
"ping",
|
||||
"echo",
|
||||
"select"
|
||||
}
|
||||
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
return setmetatable({ sock = sock }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:connect(...)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
function _M.get_reused_times(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:getreusedtimes()
|
||||
end
|
||||
|
||||
|
||||
local function close(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
_M.close = close
|
||||
|
||||
|
||||
local function _read_reply(self, sock)
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local prefix = byte(line)
|
||||
|
||||
if prefix == 36 then -- char '$'
|
||||
-- print("bulk reply")
|
||||
|
||||
local size = tonumber(sub(line, 2))
|
||||
if size < 0 then
|
||||
return null
|
||||
end
|
||||
|
||||
local data, err = sock:receive(size)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local dummy, err = sock:receive(2) -- ignore CRLF
|
||||
if not dummy then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return data
|
||||
|
||||
elseif prefix == 43 then -- char '+'
|
||||
-- print("status reply")
|
||||
|
||||
return sub(line, 2)
|
||||
|
||||
elseif prefix == 42 then -- char '*'
|
||||
local n = tonumber(sub(line, 2))
|
||||
|
||||
-- print("multi-bulk reply: ", n)
|
||||
if n < 0 then
|
||||
return null
|
||||
end
|
||||
|
||||
local vals = new_tab(n, 0);
|
||||
local nvals = 0
|
||||
for i = 1, n do
|
||||
local res, err = _read_reply(self, sock)
|
||||
if res then
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = res
|
||||
|
||||
elseif res == nil then
|
||||
return nil, err
|
||||
|
||||
else
|
||||
-- be a valid redis error value
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = {false, err}
|
||||
end
|
||||
end
|
||||
|
||||
return vals
|
||||
|
||||
elseif prefix == 58 then -- char ':'
|
||||
-- print("integer reply")
|
||||
return tonumber(sub(line, 2))
|
||||
|
||||
elseif prefix == 45 then -- char '-'
|
||||
-- print("error reply: ", n)
|
||||
|
||||
return false, sub(line, 2)
|
||||
|
||||
else
|
||||
return nil, "unkown prefix: \"" .. prefix .. "\""
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function _gen_req(args)
|
||||
local nargs = #args
|
||||
|
||||
local req = new_tab(nargs + 1, 0)
|
||||
req[1] = "*" .. nargs .. "\r\n"
|
||||
local nbits = 1
|
||||
|
||||
for i = 1, nargs do
|
||||
local arg = args[i]
|
||||
nbits = nbits + 1
|
||||
|
||||
if not arg then
|
||||
req[nbits] = "$-1\r\n"
|
||||
|
||||
else
|
||||
if type(arg) ~= "string" then
|
||||
arg = tostring(arg)
|
||||
end
|
||||
req[nbits] = "$" .. #arg .. "\r\n" .. arg .. "\r\n"
|
||||
end
|
||||
end
|
||||
|
||||
-- it is faster to do string concatenation on the Lua land
|
||||
return concat(req)
|
||||
end
|
||||
|
||||
|
||||
local function _do_cmd(self, ...)
|
||||
local args = {...}
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req = _gen_req(args)
|
||||
|
||||
local reqs = self._reqs
|
||||
if reqs then
|
||||
reqs[#reqs + 1] = req
|
||||
return
|
||||
end
|
||||
|
||||
-- print("request: ", table.concat(req))
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return _read_reply(self, sock)
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
function _M.read_reply(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local res, err = _read_reply(self, sock)
|
||||
|
||||
return res, err
|
||||
end
|
||||
|
||||
|
||||
for i = 1, #commands do
|
||||
local cmd = commands[i]
|
||||
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
return _do_cmd(self, cmd, ...)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _M.hmset(self, hashname, ...)
|
||||
local args = {...}
|
||||
if #args == 1 then
|
||||
local t = args[1]
|
||||
|
||||
local n = 0
|
||||
for k, v in pairs(t) do
|
||||
n = n + 2
|
||||
end
|
||||
|
||||
local array = new_tab(n, 0)
|
||||
|
||||
local i = 0
|
||||
for k, v in pairs(t) do
|
||||
array[i + 1] = k
|
||||
array[i + 2] = v
|
||||
i = i + 2
|
||||
end
|
||||
-- print("key", hashname)
|
||||
return _do_cmd(self, "hmset", hashname, unpack(array))
|
||||
end
|
||||
|
||||
-- backwards compatibility
|
||||
return _do_cmd(self, "hmset", hashname, ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.array_to_hash(self, t)
|
||||
local n = #t
|
||||
-- print("n = ", n)
|
||||
local h = new_tab(0, n / 2)
|
||||
for i = 1, n, 2 do
|
||||
h[t[i]] = t[i + 1]
|
||||
end
|
||||
return h
|
||||
end
|
||||
|
||||
|
||||
function _M.add_commands(...)
|
||||
local cmds = {...}
|
||||
for i = 1, #cmds do
|
||||
local cmd = cmds[i]
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
return _do_cmd(self, cmd, ...)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
|
@ -3,7 +3,7 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
"compression": false,
|
||||
"block_size": 32768,
|
||||
"write_buffer_size": 67108864,
|
||||
"cache_size": 524288000
|
||||
"cache_size": 524288000,
|
||||
"max_open_files":1024
|
||||
}
|
||||
},
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
@ -73,7 +73,9 @@ func (l *Ledis) Dump(w io.Writer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
it := sp.Iterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
it := sp.NewIterator()
|
||||
it.SeekToFirst()
|
||||
|
||||
var key []byte
|
||||
var value []byte
|
||||
for ; it.Valid(); it.Next() {
|
||||
|
|
|
@ -2,7 +2,7 @@ package ledis
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
@ -59,7 +59,7 @@ func TestDump(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
key := it.Key()
|
||||
value := it.Value()
|
||||
|
|
|
@ -3,8 +3,8 @@ package ledis
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -46,6 +46,7 @@ type Ledis struct {
|
|||
binlog *BinLog
|
||||
|
||||
quit chan struct{}
|
||||
jobs *sync.WaitGroup
|
||||
}
|
||||
|
||||
func Open(configJson json.RawMessage) (*Ledis, error) {
|
||||
|
@ -75,6 +76,7 @@ func OpenWithConfig(cfg *Config) (*Ledis, error) {
|
|||
l := new(Ledis)
|
||||
|
||||
l.quit = make(chan struct{})
|
||||
l.jobs = new(sync.WaitGroup)
|
||||
|
||||
l.ldb = ldb
|
||||
|
||||
|
@ -118,6 +120,7 @@ func newDB(l *Ledis, index uint8) *DB {
|
|||
|
||||
func (l *Ledis) Close() {
|
||||
close(l.quit)
|
||||
l.jobs.Wait()
|
||||
|
||||
l.ldb.Close()
|
||||
|
||||
|
@ -156,19 +159,23 @@ func (l *Ledis) activeExpireCycle() {
|
|||
executors[i] = db.newEliminator()
|
||||
}
|
||||
|
||||
l.jobs.Add(1)
|
||||
go func() {
|
||||
tick := time.NewTicker(1 * time.Second)
|
||||
for {
|
||||
end := false
|
||||
for !end {
|
||||
select {
|
||||
case <-tick.C:
|
||||
for _, eli := range executors {
|
||||
eli.active()
|
||||
}
|
||||
case <-l.quit:
|
||||
end = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tick.Stop()
|
||||
l.jobs.Done()
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
|
|
@ -3,14 +3,14 @@ package ledis
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func checkLedisEqual(master *Ledis, slave *Ledis) error {
|
||||
it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
key := it.Key()
|
||||
value := it.Value()
|
||||
|
|
|
@ -3,7 +3,7 @@ package ledis
|
|||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -132,7 +132,7 @@ func (db *DB) hDelete(t *tx, key []byte) int64 {
|
|||
stop := db.hEncodeStopKey(key)
|
||||
|
||||
var num int64 = 0
|
||||
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
num++
|
||||
|
@ -232,10 +232,11 @@ func (db *DB) HMset(key []byte, args ...FVPair) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) {
|
||||
func (db *DB) HMget(key []byte, args ...[]byte) ([]interface{}, error) {
|
||||
var ek []byte
|
||||
var v []byte
|
||||
var err error
|
||||
|
||||
it := db.db.NewIterator()
|
||||
defer it.Close()
|
||||
|
||||
r := make([]interface{}, len(args))
|
||||
for i := 0; i < len(args); i++ {
|
||||
|
@ -245,17 +246,13 @@ func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) {
|
|||
|
||||
ek = db.hEncodeHashKey(key, args[i])
|
||||
|
||||
if v, err = db.db.Get(ek); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r[i] = v
|
||||
r[i] = it.Find(ek)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (db *DB) HDel(key []byte, args [][]byte) (int64, error) {
|
||||
func (db *DB) HDel(key []byte, args ...[]byte) (int64, error) {
|
||||
t := db.hashTx
|
||||
|
||||
var ek []byte
|
||||
|
@ -265,6 +262,9 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) {
|
|||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
it := db.db.NewIterator()
|
||||
defer it.Close()
|
||||
|
||||
var num int64 = 0
|
||||
for i := 0; i < len(args); i++ {
|
||||
if err := checkHashKFSize(key, args[i]); err != nil {
|
||||
|
@ -273,9 +273,8 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) {
|
|||
|
||||
ek = db.hEncodeHashKey(key, args[i])
|
||||
|
||||
if v, err = db.db.Get(ek); err != nil {
|
||||
return 0, err
|
||||
} else if v == nil {
|
||||
v = it.Find(ek)
|
||||
if v == nil {
|
||||
continue
|
||||
} else {
|
||||
num++
|
||||
|
@ -355,7 +354,7 @@ func (db *DB) HGetAll(key []byte) ([]interface{}, error) {
|
|||
|
||||
v := make([]interface{}, 0, 16)
|
||||
|
||||
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
_, k, err := db.hDecodeHashKey(it.Key())
|
||||
if err != nil {
|
||||
|
@ -380,7 +379,7 @@ func (db *DB) HKeys(key []byte) ([]interface{}, error) {
|
|||
|
||||
v := make([]interface{}, 0, 16)
|
||||
|
||||
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
_, k, err := db.hDecodeHashKey(it.Key())
|
||||
if err != nil {
|
||||
|
@ -404,7 +403,7 @@ func (db *DB) HValues(key []byte) ([]interface{}, error) {
|
|||
|
||||
v := make([]interface{}, 0, 16)
|
||||
|
||||
it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
v = append(v, it.Value())
|
||||
}
|
||||
|
@ -443,7 +442,7 @@ func (db *DB) hFlush() (drop int64, err error) {
|
|||
maxKey[0] = db.index
|
||||
maxKey[1] = hSizeType + 1
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
drop++
|
||||
|
@ -485,7 +484,7 @@ func (db *DB) HScan(key []byte, field []byte, count int, inclusive bool) ([]FVPa
|
|||
rangeType = leveldb.RangeOpen
|
||||
}
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
|
||||
for ; it.Valid(); it.Next() {
|
||||
if _, f, err := db.hDecodeHashKey(it.Key()); err != nil {
|
||||
continue
|
||||
|
|
|
@ -32,11 +32,28 @@ func TestDBHash(t *testing.T) {
|
|||
|
||||
key := []byte("testdb_hash_a")
|
||||
|
||||
if n, err := db.HSet(key, []byte("a"), []byte("hello world")); err != nil {
|
||||
if n, err := db.HSet(key, []byte("a"), []byte("hello world 1")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
if n, err := db.HSet(key, []byte("b"), []byte("hello world 2")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
ay, _ := db.HMget(key, []byte("a"), []byte("b"))
|
||||
|
||||
if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" {
|
||||
t.Fatal(string(v1))
|
||||
}
|
||||
|
||||
if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" {
|
||||
t.Fatal(string(v2))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDBHScan(t *testing.T) {
|
||||
|
|
|
@ -2,7 +2,7 @@ package ledis
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -204,18 +204,15 @@ func (db *DB) IncryBy(key []byte, increment int64) (int64, error) {
|
|||
func (db *DB) MGet(keys ...[]byte) ([]interface{}, error) {
|
||||
values := make([]interface{}, len(keys))
|
||||
|
||||
var err error
|
||||
var value []byte
|
||||
it := db.db.NewIterator()
|
||||
defer it.Close()
|
||||
|
||||
for i := range keys {
|
||||
if err := checkKeySize(keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value, err = db.db.Get(db.encodeKVKey(keys[i])); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values[i] = value
|
||||
values[i] = it.Find(db.encodeKVKey(keys[i]))
|
||||
}
|
||||
|
||||
return values, nil
|
||||
|
@ -319,7 +316,7 @@ func (db *DB) flush() (drop int64, err error) {
|
|||
minKey := db.encodeKVMinKey()
|
||||
maxKey := db.encodeKVMaxKey()
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
drop++
|
||||
|
@ -362,7 +359,7 @@ func (db *DB) Scan(key []byte, count int, inclusive bool) ([]KVPair, error) {
|
|||
rangeType = leveldb.RangeOpen
|
||||
}
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
|
||||
for ; it.Valid(); it.Next() {
|
||||
if key, err := db.decodeKVKey(it.Key()); err != nil {
|
||||
continue
|
||||
|
|
|
@ -19,11 +19,28 @@ func TestKVCodec(t *testing.T) {
|
|||
func TestDBKV(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key := []byte("testdb_kv_a")
|
||||
key1 := []byte("testdb_kv_a")
|
||||
|
||||
if err := db.Set(key, []byte("hello world")); err != nil {
|
||||
if err := db.Set(key1, []byte("hello world 1")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key2 := []byte("testdb_kv_b")
|
||||
|
||||
if err := db.Set(key2, []byte("hello world 2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ay, _ := db.MGet(key1, key2)
|
||||
|
||||
if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" {
|
||||
t.Fatal(string(v1))
|
||||
}
|
||||
|
||||
if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" {
|
||||
t.Fatal(string(v2))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDBScan(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package ledis
|
|||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -200,7 +200,7 @@ func (db *DB) lDelete(t *tx, key []byte) int64 {
|
|||
startKey := db.lEncodeListKey(key, headSeq)
|
||||
stopKey := db.lEncodeListKey(key, tailSeq)
|
||||
|
||||
it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
|
||||
it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
num++
|
||||
|
@ -361,7 +361,7 @@ func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error)
|
|||
|
||||
startKey := db.lEncodeListKey(key, startSeq)
|
||||
stopKey := db.lEncodeListKey(key, stopSeq)
|
||||
it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
|
||||
it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
v = append(v, it.Value())
|
||||
}
|
||||
|
@ -408,7 +408,7 @@ func (db *DB) lFlush() (drop int64, err error) {
|
|||
maxKey[0] = db.index
|
||||
maxKey[1] = lMetaType + 1
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
drop++
|
||||
|
|
|
@ -3,7 +3,7 @@ package ledis
|
|||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -119,7 +119,7 @@ func (db *DB) expFlush(t *tx, expType byte) (err error) {
|
|||
maxKey[0] = db.index
|
||||
maxKey[1] = expMetaType + 1
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
drop++
|
||||
|
@ -173,7 +173,7 @@ func (eli *elimination) active() {
|
|||
continue
|
||||
}
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for it.Valid() {
|
||||
for i := 1; i < 512 && it.Valid(); i++ {
|
||||
expKeys = append(expKeys, it.Key(), it.Value())
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -434,7 +434,7 @@ func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) {
|
|||
|
||||
rangeType := leveldb.RangeROpen
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, rangeType, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, -1)
|
||||
var n int64 = 0
|
||||
for ; it.Valid(); it.Next() {
|
||||
n++
|
||||
|
@ -459,16 +459,16 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) {
|
|||
if s, err := Int64(v, err); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
var it *leveldb.Iterator
|
||||
var it *leveldb.RangeLimitIterator
|
||||
|
||||
sk := db.zEncodeScoreKey(key, member, s)
|
||||
|
||||
if !reverse {
|
||||
minKey := db.zEncodeStartScoreKey(key, MinScore)
|
||||
it = db.db.Iterator(minKey, sk, leveldb.RangeClose, 0, -1)
|
||||
it = db.db.RangeLimitIterator(minKey, sk, leveldb.RangeClose, 0, -1)
|
||||
} else {
|
||||
maxKey := db.zEncodeStopScoreKey(key, MaxScore)
|
||||
it = db.db.RevIterator(sk, maxKey, leveldb.RangeClose, 0, -1)
|
||||
it = db.db.RevRangeLimitIterator(sk, maxKey, leveldb.RangeClose, 0, -1)
|
||||
}
|
||||
|
||||
var lastKey []byte = nil
|
||||
|
@ -492,14 +492,14 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) {
|
|||
return -1, nil
|
||||
}
|
||||
|
||||
func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.Iterator {
|
||||
func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.RangeLimitIterator {
|
||||
minKey := db.zEncodeStartScoreKey(key, min)
|
||||
maxKey := db.zEncodeStopScoreKey(key, max)
|
||||
|
||||
if !reverse {
|
||||
return db.db.Iterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
|
||||
return db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
|
||||
} else {
|
||||
return db.db.RevIterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
|
||||
return db.db.RevRangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -567,7 +567,7 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i
|
|||
}
|
||||
v := make([]interface{}, 0, nv)
|
||||
|
||||
var it *leveldb.Iterator
|
||||
var it *leveldb.RangeLimitIterator
|
||||
|
||||
//if reverse and offset is 0, limit < 0, we may use forward iterator then reverse
|
||||
//because leveldb iterator prev is slower than next
|
||||
|
@ -745,7 +745,7 @@ func (db *DB) zFlush() (drop int64, err error) {
|
|||
maxKey[0] = db.index
|
||||
maxKey[1] = zScoreType + 1
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
t.Delete(it.Key())
|
||||
drop++
|
||||
|
@ -788,7 +788,7 @@ func (db *DB) ZScan(key []byte, member []byte, count int, inclusive bool) ([]Sco
|
|||
rangeType = leveldb.RangeOpen
|
||||
}
|
||||
|
||||
it := db.db.Iterator(minKey, maxKey, rangeType, 0, count)
|
||||
it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count)
|
||||
for ; it.Valid(); it.Next() {
|
||||
if _, m, err := db.zDecodeSetKey(it.Key()); err != nil {
|
||||
continue
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package ledis
|
||||
|
||||
import (
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type WriteBatch struct {
|
||||
db *DB
|
||||
wbatch *C.leveldb_writebatch_t
|
||||
}
|
||||
|
||||
func (w *WriteBatch) Close() {
|
||||
C.leveldb_writebatch_destroy(w.wbatch)
|
||||
}
|
||||
|
||||
func (w *WriteBatch) Put(key, value []byte) {
|
||||
var k, v *C.char
|
||||
if len(key) != 0 {
|
||||
k = (*C.char)(unsafe.Pointer(&key[0]))
|
||||
}
|
||||
if len(value) != 0 {
|
||||
v = (*C.char)(unsafe.Pointer(&value[0]))
|
||||
}
|
||||
|
||||
lenk := len(key)
|
||||
lenv := len(value)
|
||||
|
||||
C.leveldb_writebatch_put(w.wbatch, k, C.size_t(lenk), v, C.size_t(lenv))
|
||||
}
|
||||
|
||||
func (w *WriteBatch) Delete(key []byte) {
|
||||
C.leveldb_writebatch_delete(w.wbatch,
|
||||
(*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key)))
|
||||
}
|
||||
|
||||
func (w *WriteBatch) Commit() error {
|
||||
return w.commit(w.db.writeOpts)
|
||||
}
|
||||
|
||||
func (w *WriteBatch) SyncCommit() error {
|
||||
return w.commit(w.db.syncWriteOpts)
|
||||
}
|
||||
|
||||
func (w *WriteBatch) Rollback() {
|
||||
C.leveldb_writebatch_clear(w.wbatch)
|
||||
}
|
||||
|
||||
func (w *WriteBatch) commit(wb *WriteOptions) error {
|
||||
var errStr *C.char
|
||||
C.leveldb_write(w.db.db, wb.Opt, w.wbatch, &errStr)
|
||||
if errStr != nil {
|
||||
return saveError(errStr)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include <stdint.h>
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
type Cache struct {
|
||||
Cache *C.leveldb_cache_t
|
||||
}
|
||||
|
||||
func NewLRUCache(capacity int) *Cache {
|
||||
return &Cache{C.leveldb_cache_create_lru(C.size_t(capacity))}
|
||||
}
|
||||
|
||||
func (c *Cache) Close() {
|
||||
C.leveldb_cache_destroy(c.Cache)
|
||||
}
|
|
@ -0,0 +1,328 @@
|
|||
package leveldb
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lleveldb
|
||||
#include <leveldb/c.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const defaultFilterBits int = 10
|
||||
|
||||
type Config struct {
|
||||
Path string `json:"path"`
|
||||
|
||||
Compression bool `json:"compression"`
|
||||
BlockSize int `json:"block_size"`
|
||||
WriteBufferSize int `json:"write_buffer_size"`
|
||||
CacheSize int `json:"cache_size"`
|
||||
MaxOpenFiles int `json:"max_open_files"`
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
cfg *Config
|
||||
|
||||
db *C.leveldb_t
|
||||
|
||||
opts *Options
|
||||
|
||||
//for default read and write options
|
||||
readOpts *ReadOptions
|
||||
writeOpts *WriteOptions
|
||||
iteratorOpts *ReadOptions
|
||||
|
||||
syncWriteOpts *WriteOptions
|
||||
|
||||
cache *Cache
|
||||
|
||||
filter *FilterPolicy
|
||||
}
|
||||
|
||||
func Open(configJson json.RawMessage) (*DB, error) {
|
||||
cfg := new(Config)
|
||||
err := json.Unmarshal(configJson, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return OpenWithConfig(cfg)
|
||||
}
|
||||
|
||||
func OpenWithConfig(cfg *Config) (*DB, error) {
|
||||
if err := os.MkdirAll(cfg.Path, os.ModePerm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := new(DB)
|
||||
db.cfg = cfg
|
||||
|
||||
if err := db.open(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (db *DB) open() error {
|
||||
db.opts = db.initOptions(db.cfg)
|
||||
|
||||
db.readOpts = NewReadOptions()
|
||||
db.writeOpts = NewWriteOptions()
|
||||
|
||||
db.iteratorOpts = NewReadOptions()
|
||||
db.iteratorOpts.SetFillCache(false)
|
||||
|
||||
db.syncWriteOpts = NewWriteOptions()
|
||||
db.syncWriteOpts.SetSync(true)
|
||||
|
||||
var errStr *C.char
|
||||
ldbname := C.CString(db.cfg.Path)
|
||||
defer C.leveldb_free(unsafe.Pointer(ldbname))
|
||||
|
||||
db.db = C.leveldb_open(db.opts.Opt, ldbname, &errStr)
|
||||
if errStr != nil {
|
||||
return saveError(errStr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) initOptions(cfg *Config) *Options {
|
||||
opts := NewOptions()
|
||||
|
||||
opts.SetCreateIfMissing(true)
|
||||
|
||||
if cfg.CacheSize <= 0 {
|
||||
cfg.CacheSize = 4 * 1024 * 1024
|
||||
}
|
||||
|
||||
db.cache = NewLRUCache(cfg.CacheSize)
|
||||
opts.SetCache(db.cache)
|
||||
|
||||
//we must use bloomfilter
|
||||
db.filter = NewBloomFilter(defaultFilterBits)
|
||||
opts.SetFilterPolicy(db.filter)
|
||||
|
||||
if !cfg.Compression {
|
||||
opts.SetCompression(NoCompression)
|
||||
}
|
||||
|
||||
if cfg.BlockSize <= 0 {
|
||||
cfg.BlockSize = 4 * 1024
|
||||
}
|
||||
|
||||
opts.SetBlockSize(cfg.BlockSize)
|
||||
|
||||
if cfg.WriteBufferSize <= 0 {
|
||||
cfg.WriteBufferSize = 4 * 1024 * 1024
|
||||
}
|
||||
|
||||
opts.SetWriteBufferSize(cfg.WriteBufferSize)
|
||||
|
||||
if cfg.MaxOpenFiles < 1024 {
|
||||
cfg.MaxOpenFiles = 1024
|
||||
}
|
||||
|
||||
opts.SetMaxOpenFiles(cfg.MaxOpenFiles)
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func (db *DB) Close() {
|
||||
C.leveldb_close(db.db)
|
||||
db.db = nil
|
||||
|
||||
db.opts.Close()
|
||||
|
||||
if db.cache != nil {
|
||||
db.cache.Close()
|
||||
}
|
||||
|
||||
if db.filter != nil {
|
||||
db.filter.Close()
|
||||
}
|
||||
|
||||
db.readOpts.Close()
|
||||
db.writeOpts.Close()
|
||||
db.iteratorOpts.Close()
|
||||
db.syncWriteOpts.Close()
|
||||
}
|
||||
|
||||
func (db *DB) Destroy() error {
|
||||
path := db.cfg.Path
|
||||
|
||||
db.Close()
|
||||
|
||||
opts := NewOptions()
|
||||
defer opts.Close()
|
||||
|
||||
var errStr *C.char
|
||||
ldbname := C.CString(path)
|
||||
defer C.leveldb_free(unsafe.Pointer(ldbname))
|
||||
|
||||
C.leveldb_destroy_db(opts.Opt, ldbname, &errStr)
|
||||
if errStr != nil {
|
||||
return saveError(errStr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Clear() error {
|
||||
bc := db.NewWriteBatch()
|
||||
defer bc.Close()
|
||||
|
||||
var err error
|
||||
it := db.NewIterator()
|
||||
it.SeekToFirst()
|
||||
|
||||
num := 0
|
||||
for ; it.Valid(); it.Next() {
|
||||
bc.Delete(it.Key())
|
||||
num++
|
||||
if num == 1000 {
|
||||
num = 0
|
||||
if err = bc.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = bc.Commit()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) Put(key, value []byte) error {
|
||||
return db.put(db.writeOpts, key, value)
|
||||
}
|
||||
|
||||
func (db *DB) SyncPut(key, value []byte) error {
|
||||
return db.put(db.syncWriteOpts, key, value)
|
||||
}
|
||||
|
||||
func (db *DB) Get(key []byte) ([]byte, error) {
|
||||
return db.get(db.readOpts, key)
|
||||
}
|
||||
|
||||
func (db *DB) Delete(key []byte) error {
|
||||
return db.delete(db.writeOpts, key)
|
||||
}
|
||||
|
||||
func (db *DB) SyncDelete(key []byte) error {
|
||||
return db.delete(db.syncWriteOpts, key)
|
||||
}
|
||||
|
||||
func (db *DB) NewWriteBatch() *WriteBatch {
|
||||
wb := &WriteBatch{
|
||||
db: db,
|
||||
wbatch: C.leveldb_writebatch_create(),
|
||||
}
|
||||
return wb
|
||||
}
|
||||
|
||||
func (db *DB) NewSnapshot() *Snapshot {
|
||||
s := &Snapshot{
|
||||
db: db,
|
||||
snap: C.leveldb_create_snapshot(db.db),
|
||||
readOpts: NewReadOptions(),
|
||||
iteratorOpts: NewReadOptions(),
|
||||
}
|
||||
|
||||
s.readOpts.SetSnapshot(s)
|
||||
s.iteratorOpts.SetSnapshot(s)
|
||||
s.iteratorOpts.SetFillCache(false)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (db *DB) NewIterator() *Iterator {
|
||||
it := new(Iterator)
|
||||
|
||||
it.it = C.leveldb_create_iterator(db.db, db.iteratorOpts.Opt)
|
||||
|
||||
return it
|
||||
}
|
||||
|
||||
func (db *DB) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward)
|
||||
}
|
||||
|
||||
func (db *DB) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward)
|
||||
}
|
||||
|
||||
//limit < 0, unlimit
|
||||
//offset must >= 0, if < 0, will get nothing
|
||||
func (db *DB) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward)
|
||||
}
|
||||
|
||||
//limit < 0, unlimit
|
||||
//offset must >= 0, if < 0, will get nothing
|
||||
func (db *DB) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward)
|
||||
}
|
||||
|
||||
func (db *DB) put(wo *WriteOptions, key, value []byte) error {
|
||||
var errStr *C.char
|
||||
var k, v *C.char
|
||||
if len(key) != 0 {
|
||||
k = (*C.char)(unsafe.Pointer(&key[0]))
|
||||
}
|
||||
if len(value) != 0 {
|
||||
v = (*C.char)(unsafe.Pointer(&value[0]))
|
||||
}
|
||||
|
||||
lenk := len(key)
|
||||
lenv := len(value)
|
||||
C.leveldb_put(
|
||||
db.db, wo.Opt, k, C.size_t(lenk), v, C.size_t(lenv), &errStr)
|
||||
|
||||
if errStr != nil {
|
||||
return saveError(errStr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) {
|
||||
var errStr *C.char
|
||||
var vallen C.size_t
|
||||
var k *C.char
|
||||
if len(key) != 0 {
|
||||
k = (*C.char)(unsafe.Pointer(&key[0]))
|
||||
}
|
||||
|
||||
value := C.leveldb_get(
|
||||
db.db, ro.Opt, k, C.size_t(len(key)), &vallen, &errStr)
|
||||
|
||||
if errStr != nil {
|
||||
return nil, saveError(errStr)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
defer C.leveldb_free(unsafe.Pointer(value))
|
||||
return C.GoBytes(unsafe.Pointer(value), C.int(vallen)), nil
|
||||
}
|
||||
|
||||
func (db *DB) delete(wo *WriteOptions, key []byte) error {
|
||||
var errStr *C.char
|
||||
var k *C.char
|
||||
if len(key) != 0 {
|
||||
k = (*C.char)(unsafe.Pointer(&key[0]))
|
||||
}
|
||||
|
||||
C.leveldb_delete(
|
||||
db.db, wo.Opt, k, C.size_t(len(key)), &errStr)
|
||||
|
||||
if errStr != nil {
|
||||
return saveError(errStr)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include <stdlib.h>
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
type FilterPolicy struct {
|
||||
Policy *C.leveldb_filterpolicy_t
|
||||
}
|
||||
|
||||
func NewBloomFilter(bitsPerKey int) *FilterPolicy {
|
||||
policy := C.leveldb_filterpolicy_create_bloom(C.int(bitsPerKey))
|
||||
return &FilterPolicy{policy}
|
||||
}
|
||||
|
||||
func (fp *FilterPolicy) Close() {
|
||||
C.leveldb_filterpolicy_destroy(fp.Policy)
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include <stdlib.h>
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
IteratorForward uint8 = 0
|
||||
IteratorBackward uint8 = 1
|
||||
)
|
||||
|
||||
const (
|
||||
RangeClose uint8 = 0x00
|
||||
RangeLOpen uint8 = 0x01
|
||||
RangeROpen uint8 = 0x10
|
||||
RangeOpen uint8 = 0x11
|
||||
)
|
||||
|
||||
//min must less or equal than max
|
||||
//range type:
|
||||
//close: [min, max]
|
||||
//open: (min, max)
|
||||
//lopen: (min, max]
|
||||
//ropen: [min, max)
|
||||
type Range struct {
|
||||
Min []byte
|
||||
Max []byte
|
||||
|
||||
Type uint8
|
||||
}
|
||||
|
||||
type Iterator struct {
|
||||
it *C.leveldb_iterator_t
|
||||
}
|
||||
|
||||
func (it *Iterator) Key() []byte {
|
||||
var klen C.size_t
|
||||
kdata := C.leveldb_iter_key(it.it, &klen)
|
||||
if kdata == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return C.GoBytes(unsafe.Pointer(kdata), C.int(klen))
|
||||
}
|
||||
|
||||
func (it *Iterator) Value() []byte {
|
||||
var vlen C.size_t
|
||||
vdata := C.leveldb_iter_value(it.it, &vlen)
|
||||
if vdata == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return C.GoBytes(unsafe.Pointer(vdata), C.int(vlen))
|
||||
}
|
||||
|
||||
func (it *Iterator) Close() {
|
||||
C.leveldb_iter_destroy(it.it)
|
||||
it.it = nil
|
||||
}
|
||||
|
||||
func (it *Iterator) Valid() bool {
|
||||
return ucharToBool(C.leveldb_iter_valid(it.it))
|
||||
}
|
||||
|
||||
func (it *Iterator) Next() {
|
||||
C.leveldb_iter_next(it.it)
|
||||
}
|
||||
|
||||
func (it *Iterator) Prev() {
|
||||
C.leveldb_iter_prev(it.it)
|
||||
}
|
||||
|
||||
func (it *Iterator) SeekToFirst() {
|
||||
C.leveldb_iter_seek_to_first(it.it)
|
||||
}
|
||||
|
||||
func (it *Iterator) SeekToLast() {
|
||||
C.leveldb_iter_seek_to_last(it.it)
|
||||
}
|
||||
|
||||
func (it *Iterator) Seek(key []byte) {
|
||||
C.leveldb_iter_seek(it.it, (*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key)))
|
||||
}
|
||||
|
||||
func (it *Iterator) Find(key []byte) []byte {
|
||||
it.Seek(key)
|
||||
if it.Valid() {
|
||||
var klen C.size_t
|
||||
kdata := C.leveldb_iter_key(it.it, &klen)
|
||||
if kdata == nil {
|
||||
return nil
|
||||
} else if bytes.Equal(slice(unsafe.Pointer(kdata), int(C.int(klen))), key) {
|
||||
return it.Value()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type RangeLimitIterator struct {
|
||||
it *Iterator
|
||||
|
||||
r *Range
|
||||
|
||||
offset int
|
||||
limit int
|
||||
|
||||
step int
|
||||
|
||||
//0 for IteratorForward, 1 for IteratorBackward
|
||||
direction uint8
|
||||
}
|
||||
|
||||
func (it *RangeLimitIterator) Key() []byte {
|
||||
return it.it.Key()
|
||||
}
|
||||
|
||||
func (it *RangeLimitIterator) Value() []byte {
|
||||
return it.it.Value()
|
||||
}
|
||||
|
||||
func (it *RangeLimitIterator) Valid() bool {
|
||||
if it.offset < 0 {
|
||||
return false
|
||||
} else if !it.it.Valid() {
|
||||
return false
|
||||
} else if it.limit >= 0 && it.step >= it.limit {
|
||||
return false
|
||||
}
|
||||
|
||||
if it.direction == IteratorForward {
|
||||
if it.r.Max != nil {
|
||||
r := bytes.Compare(it.it.Key(), it.r.Max)
|
||||
if it.r.Type&RangeROpen > 0 {
|
||||
return !(r >= 0)
|
||||
} else {
|
||||
return !(r > 0)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if it.r.Min != nil {
|
||||
r := bytes.Compare(it.it.Key(), it.r.Min)
|
||||
if it.r.Type&RangeLOpen > 0 {
|
||||
return !(r <= 0)
|
||||
} else {
|
||||
return !(r < 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (it *RangeLimitIterator) Next() {
|
||||
it.step++
|
||||
|
||||
if it.direction == IteratorForward {
|
||||
it.it.Next()
|
||||
} else {
|
||||
it.it.Prev()
|
||||
}
|
||||
}
|
||||
|
||||
func (it *RangeLimitIterator) Close() {
|
||||
it.it.Close()
|
||||
}
|
||||
|
||||
func newRangeLimitIterator(i *Iterator, r *Range, offset int, limit int, direction uint8) *RangeLimitIterator {
|
||||
it := new(RangeLimitIterator)
|
||||
|
||||
it.it = i
|
||||
|
||||
it.r = r
|
||||
it.offset = offset
|
||||
it.limit = limit
|
||||
it.direction = direction
|
||||
|
||||
it.step = 0
|
||||
|
||||
if offset < 0 {
|
||||
return it
|
||||
}
|
||||
|
||||
if direction == IteratorForward {
|
||||
if r.Min == nil {
|
||||
it.it.SeekToFirst()
|
||||
} else {
|
||||
it.it.Seek(r.Min)
|
||||
|
||||
if r.Type&RangeLOpen > 0 {
|
||||
if it.it.Valid() && bytes.Equal(it.it.Key(), r.Min) {
|
||||
it.it.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if r.Max == nil {
|
||||
it.it.SeekToLast()
|
||||
} else {
|
||||
it.it.Seek(r.Max)
|
||||
|
||||
if !it.it.Valid() {
|
||||
it.it.SeekToLast()
|
||||
} else {
|
||||
if !bytes.Equal(it.it.Key(), r.Max) {
|
||||
it.it.Prev()
|
||||
}
|
||||
}
|
||||
|
||||
if r.Type&RangeROpen > 0 {
|
||||
if it.it.Valid() && bytes.Equal(it.it.Key(), r.Max) {
|
||||
it.it.Prev()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < offset; i++ {
|
||||
if it.it.Valid() {
|
||||
if it.direction == IteratorForward {
|
||||
it.it.Next()
|
||||
} else {
|
||||
it.it.Prev()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return it
|
||||
}
|
|
@ -0,0 +1,259 @@
|
|||
package leveldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testConfigJson = []byte(`
|
||||
{
|
||||
"path" : "./testdb",
|
||||
"compression":true,
|
||||
"block_size" : 32768,
|
||||
"write_buffer_size" : 2097152,
|
||||
"cache_size" : 20971520
|
||||
}
|
||||
`)
|
||||
|
||||
var testOnce sync.Once
|
||||
var testDB *DB
|
||||
|
||||
func getTestDB() *DB {
|
||||
f := func() {
|
||||
var err error
|
||||
testDB, err = Open(testConfigJson)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
testOnce.Do(f)
|
||||
return testDB
|
||||
}
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key := []byte("key")
|
||||
value := []byte("hello world")
|
||||
if err := db.Put(key, value); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := db.Get(key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !bytes.Equal(v, value) {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
|
||||
if err := db.Delete(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v, err := db.Get(key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if v != nil {
|
||||
t.Fatal("must nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatch(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key1 := []byte("key1")
|
||||
key2 := []byte("key2")
|
||||
|
||||
value := []byte("hello world")
|
||||
|
||||
db.Put(key1, value)
|
||||
db.Put(key2, value)
|
||||
|
||||
wb := db.NewWriteBatch()
|
||||
defer wb.Close()
|
||||
|
||||
wb.Delete(key2)
|
||||
wb.Put(key1, []byte("hello world2"))
|
||||
|
||||
if err := wb.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := db.Get(key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if v != nil {
|
||||
t.Fatal("must nil")
|
||||
}
|
||||
|
||||
if v, err := db.Get(key1); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "hello world2" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
wb.Delete(key1)
|
||||
|
||||
wb.Rollback()
|
||||
|
||||
if v, err := db.Get(key1); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "hello world2" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
db.Delete(key1)
|
||||
}
|
||||
|
||||
func checkIterator(it *RangeLimitIterator, cv ...int) error {
|
||||
v := make([]string, 0, len(cv))
|
||||
for ; it.Valid(); it.Next() {
|
||||
k := it.Key()
|
||||
v = append(v, string(k))
|
||||
}
|
||||
|
||||
it.Close()
|
||||
|
||||
if len(v) != len(cv) {
|
||||
return fmt.Errorf("len error %d != %d", len(v), len(cv))
|
||||
}
|
||||
|
||||
for k, i := range cv {
|
||||
if fmt.Sprintf("key_%d", i) != v[k] {
|
||||
return fmt.Errorf("%s, %d", v[k], i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
db.Clear()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte(fmt.Sprintf("key_%d", i))
|
||||
value := []byte("")
|
||||
db.Put(key, value)
|
||||
}
|
||||
|
||||
var it *RangeLimitIterator
|
||||
|
||||
k := func(i int) []byte {
|
||||
return []byte(fmt.Sprintf("key_%d", i))
|
||||
}
|
||||
|
||||
it = db.RangeLimitIterator(k(1), k(5), RangeClose, 0, -1)
|
||||
if err := checkIterator(it, 1, 2, 3, 4, 5); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RangeLimitIterator(k(1), k(5), RangeClose, 1, 3)
|
||||
if err := checkIterator(it, 2, 3, 4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1)
|
||||
if err := checkIterator(it, 2, 3, 4, 5); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RangeLimitIterator(k(1), k(5), RangeROpen, 0, -1)
|
||||
if err := checkIterator(it, 1, 2, 3, 4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RangeLimitIterator(k(1), k(5), RangeOpen, 0, -1)
|
||||
if err := checkIterator(it, 2, 3, 4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 0, -1)
|
||||
if err := checkIterator(it, 5, 4, 3, 2, 1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 1, 3)
|
||||
if err := checkIterator(it, 4, 3, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RevRangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1)
|
||||
if err := checkIterator(it, 5, 4, 3, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RevRangeLimitIterator(k(1), k(5), RangeROpen, 0, -1)
|
||||
if err := checkIterator(it, 4, 3, 2, 1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it = db.RevRangeLimitIterator(k(1), k(5), RangeOpen, 0, -1)
|
||||
if err := checkIterator(it, 4, 3, 2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key := []byte("key")
|
||||
value := []byte("hello world")
|
||||
|
||||
db.Put(key, value)
|
||||
|
||||
s := db.NewSnapshot()
|
||||
defer s.Close()
|
||||
|
||||
db.Put(key, []byte("hello world2"))
|
||||
|
||||
if v, err := s.Get(key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != string(value) {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDestroy(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
db.Put([]byte("a"), []byte("1"))
|
||||
if err := db.Clear(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(db.cfg.Path); err != nil {
|
||||
t.Fatal("must exist ", err.Error())
|
||||
}
|
||||
|
||||
if v, err := db.Get([]byte("a")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) == "1" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
db.Destroy()
|
||||
|
||||
if _, err := os.Stat(db.cfg.Path); !os.IsNotExist(err) {
|
||||
t.Fatal("must not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseMore(t *testing.T) {
|
||||
cfg := new(Config)
|
||||
cfg.Path = "/tmp/testdb1234"
|
||||
cfg.CacheSize = 4 * 1024 * 1024
|
||||
os.RemoveAll(cfg.Path)
|
||||
for i := 0; i < 100; i++ {
|
||||
db, err := OpenWithConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Put([]byte("key"), []byte("value"))
|
||||
|
||||
db.Close()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
Copyright (c) 2012 Jeffrey M Hodges
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,128 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
type CompressionOpt int
|
||||
|
||||
const (
|
||||
NoCompression = CompressionOpt(0)
|
||||
SnappyCompression = CompressionOpt(1)
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Opt *C.leveldb_options_t
|
||||
}
|
||||
|
||||
type ReadOptions struct {
|
||||
Opt *C.leveldb_readoptions_t
|
||||
}
|
||||
|
||||
type WriteOptions struct {
|
||||
Opt *C.leveldb_writeoptions_t
|
||||
}
|
||||
|
||||
func NewOptions() *Options {
|
||||
opt := C.leveldb_options_create()
|
||||
return &Options{opt}
|
||||
}
|
||||
|
||||
func NewReadOptions() *ReadOptions {
|
||||
opt := C.leveldb_readoptions_create()
|
||||
return &ReadOptions{opt}
|
||||
}
|
||||
|
||||
func NewWriteOptions() *WriteOptions {
|
||||
opt := C.leveldb_writeoptions_create()
|
||||
return &WriteOptions{opt}
|
||||
}
|
||||
|
||||
func (o *Options) Close() {
|
||||
C.leveldb_options_destroy(o.Opt)
|
||||
}
|
||||
|
||||
func (o *Options) SetComparator(cmp *C.leveldb_comparator_t) {
|
||||
C.leveldb_options_set_comparator(o.Opt, cmp)
|
||||
}
|
||||
|
||||
func (o *Options) SetErrorIfExists(error_if_exists bool) {
|
||||
eie := boolToUchar(error_if_exists)
|
||||
C.leveldb_options_set_error_if_exists(o.Opt, eie)
|
||||
}
|
||||
|
||||
func (o *Options) SetCache(cache *Cache) {
|
||||
C.leveldb_options_set_cache(o.Opt, cache.Cache)
|
||||
}
|
||||
|
||||
// func (o *Options) SetEnv(env *Env) {
|
||||
// C.leveldb_options_set_env(o.Opt, env.Env)
|
||||
// }
|
||||
|
||||
func (o *Options) SetInfoLog(log *C.leveldb_logger_t) {
|
||||
C.leveldb_options_set_info_log(o.Opt, log)
|
||||
}
|
||||
|
||||
func (o *Options) SetWriteBufferSize(s int) {
|
||||
C.leveldb_options_set_write_buffer_size(o.Opt, C.size_t(s))
|
||||
}
|
||||
|
||||
func (o *Options) SetParanoidChecks(pc bool) {
|
||||
C.leveldb_options_set_paranoid_checks(o.Opt, boolToUchar(pc))
|
||||
}
|
||||
|
||||
func (o *Options) SetMaxOpenFiles(n int) {
|
||||
C.leveldb_options_set_max_open_files(o.Opt, C.int(n))
|
||||
}
|
||||
|
||||
func (o *Options) SetBlockSize(s int) {
|
||||
C.leveldb_options_set_block_size(o.Opt, C.size_t(s))
|
||||
}
|
||||
|
||||
func (o *Options) SetBlockRestartInterval(n int) {
|
||||
C.leveldb_options_set_block_restart_interval(o.Opt, C.int(n))
|
||||
}
|
||||
|
||||
func (o *Options) SetCompression(t CompressionOpt) {
|
||||
C.leveldb_options_set_compression(o.Opt, C.int(t))
|
||||
}
|
||||
|
||||
func (o *Options) SetCreateIfMissing(b bool) {
|
||||
C.leveldb_options_set_create_if_missing(o.Opt, boolToUchar(b))
|
||||
}
|
||||
|
||||
func (o *Options) SetFilterPolicy(fp *FilterPolicy) {
|
||||
var policy *C.leveldb_filterpolicy_t
|
||||
if fp != nil {
|
||||
policy = fp.Policy
|
||||
}
|
||||
C.leveldb_options_set_filter_policy(o.Opt, policy)
|
||||
}
|
||||
|
||||
func (ro *ReadOptions) Close() {
|
||||
C.leveldb_readoptions_destroy(ro.Opt)
|
||||
}
|
||||
|
||||
func (ro *ReadOptions) SetVerifyChecksums(b bool) {
|
||||
C.leveldb_readoptions_set_verify_checksums(ro.Opt, boolToUchar(b))
|
||||
}
|
||||
|
||||
func (ro *ReadOptions) SetFillCache(b bool) {
|
||||
C.leveldb_readoptions_set_fill_cache(ro.Opt, boolToUchar(b))
|
||||
}
|
||||
|
||||
func (ro *ReadOptions) SetSnapshot(snap *Snapshot) {
|
||||
var s *C.leveldb_snapshot_t
|
||||
if snap != nil {
|
||||
s = snap.snap
|
||||
}
|
||||
C.leveldb_readoptions_set_snapshot(ro.Opt, s)
|
||||
}
|
||||
|
||||
func (wo *WriteOptions) Close() {
|
||||
C.leveldb_writeoptions_destroy(wo.Opt)
|
||||
}
|
||||
|
||||
func (wo *WriteOptions) SetSync(b bool) {
|
||||
C.leveldb_writeoptions_set_sync(wo.Opt, boolToUchar(b))
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package leveldb
|
||||
|
||||
// #cgo LDFLAGS: -lleveldb
|
||||
// #include <stdint.h>
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
type Snapshot struct {
|
||||
db *DB
|
||||
|
||||
snap *C.leveldb_snapshot_t
|
||||
|
||||
readOpts *ReadOptions
|
||||
iteratorOpts *ReadOptions
|
||||
}
|
||||
|
||||
func (s *Snapshot) Close() {
|
||||
C.leveldb_release_snapshot(s.db.db, s.snap)
|
||||
|
||||
s.iteratorOpts.Close()
|
||||
s.readOpts.Close()
|
||||
}
|
||||
|
||||
func (s *Snapshot) Get(key []byte) ([]byte, error) {
|
||||
return s.db.get(s.readOpts, key)
|
||||
}
|
||||
|
||||
func (s *Snapshot) NewIterator() *Iterator {
|
||||
it := new(Iterator)
|
||||
|
||||
it.it = C.leveldb_create_iterator(s.db.db, s.iteratorOpts.Opt)
|
||||
|
||||
return it
|
||||
}
|
||||
|
||||
func (s *Snapshot) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward)
|
||||
}
|
||||
|
||||
func (s *Snapshot) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward)
|
||||
}
|
||||
|
||||
//limit < 0, unlimit
|
||||
//offset must >= 0, if < 0, will get nothing
|
||||
func (s *Snapshot) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward)
|
||||
}
|
||||
|
||||
//limit < 0, unlimit
|
||||
//offset must >= 0, if < 0, will get nothing
|
||||
func (s *Snapshot) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator {
|
||||
return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward)
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package leveldb
|
||||
|
||||
// #include "leveldb/c.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func boolToUchar(b bool) C.uchar {
|
||||
uc := C.uchar(0)
|
||||
if b {
|
||||
uc = C.uchar(1)
|
||||
}
|
||||
return uc
|
||||
}
|
||||
|
||||
func ucharToBool(uc C.uchar) bool {
|
||||
if uc == C.uchar(0) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func saveError(errStr *C.char) error {
|
||||
if errStr != nil {
|
||||
gs := C.GoString(errStr)
|
||||
C.leveldb_free(unsafe.Pointer(errStr))
|
||||
return fmt.Errorf(gs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func slice(p unsafe.Pointer, n int) []byte {
|
||||
var b []byte
|
||||
pbyte := (*reflect.SliceHeader)(unsafe.Pointer(&b))
|
||||
pbyte.Data = uintptr(p)
|
||||
pbyte.Len = n
|
||||
pbyte.Cap = n
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FileHandler struct {
|
||||
fd *os.File
|
||||
}
|
||||
|
||||
func NewFileHandler(fileName string, flag int) (*FileHandler, error) {
|
||||
dir := path.Dir(fileName)
|
||||
os.Mkdir(dir, 0777)
|
||||
|
||||
f, err := os.OpenFile(fileName, flag, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := new(FileHandler)
|
||||
|
||||
h.fd = f
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *FileHandler) Write(b []byte) (n int, err error) {
|
||||
return h.fd.Write(b)
|
||||
}
|
||||
|
||||
func (h *FileHandler) Close() error {
|
||||
return h.fd.Close()
|
||||
}
|
||||
|
||||
type RotatingFileHandler struct {
|
||||
fd *os.File
|
||||
|
||||
fileName string
|
||||
maxBytes int
|
||||
backupCount int
|
||||
}
|
||||
|
||||
func NewRotatingFileHandler(fileName string, maxBytes int, backupCount int) (*RotatingFileHandler, error) {
|
||||
dir := path.Dir(fileName)
|
||||
os.Mkdir(dir, 0777)
|
||||
|
||||
h := new(RotatingFileHandler)
|
||||
|
||||
if maxBytes <= 0 {
|
||||
return nil, fmt.Errorf("invalid max bytes")
|
||||
}
|
||||
|
||||
h.fileName = fileName
|
||||
h.maxBytes = maxBytes
|
||||
h.backupCount = backupCount
|
||||
|
||||
var err error
|
||||
h.fd, err = os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *RotatingFileHandler) Write(p []byte) (n int, err error) {
|
||||
h.doRollover()
|
||||
return h.fd.Write(p)
|
||||
}
|
||||
|
||||
func (h *RotatingFileHandler) Close() error {
|
||||
if h.fd != nil {
|
||||
return h.fd.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RotatingFileHandler) doRollover() {
|
||||
f, err := h.fd.Stat()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h.maxBytes <= 0 {
|
||||
return
|
||||
} else if f.Size() < int64(h.maxBytes) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.backupCount > 0 {
|
||||
h.fd.Close()
|
||||
|
||||
for i := h.backupCount - 1; i > 0; i-- {
|
||||
sfn := fmt.Sprintf("%s.%d", h.fileName, i)
|
||||
dfn := fmt.Sprintf("%s.%d", h.fileName, i+1)
|
||||
|
||||
os.Rename(sfn, dfn)
|
||||
}
|
||||
|
||||
dfn := fmt.Sprintf("%s.1", h.fileName)
|
||||
os.Rename(h.fileName, dfn)
|
||||
|
||||
h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
}
|
||||
}
|
||||
|
||||
//refer: http://docs.python.org/2/library/logging.handlers.html
|
||||
//same like python TimedRotatingFileHandler
|
||||
|
||||
type TimeRotatingFileHandler struct {
|
||||
fd *os.File
|
||||
|
||||
baseName string
|
||||
interval int64
|
||||
suffix string
|
||||
rolloverAt int64
|
||||
}
|
||||
|
||||
const (
|
||||
WhenSecond = iota
|
||||
WhenMinute
|
||||
WhenHour
|
||||
WhenDay
|
||||
)
|
||||
|
||||
func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) {
|
||||
dir := path.Dir(baseName)
|
||||
os.Mkdir(dir, 0777)
|
||||
|
||||
h := new(TimeRotatingFileHandler)
|
||||
|
||||
h.baseName = baseName
|
||||
|
||||
switch when {
|
||||
case WhenSecond:
|
||||
h.interval = 1
|
||||
h.suffix = "2006-01-02_15-04-05"
|
||||
case WhenMinute:
|
||||
h.interval = 60
|
||||
h.suffix = "2006-01-02_15-04"
|
||||
case WhenHour:
|
||||
h.interval = 3600
|
||||
h.suffix = "2006-01-02_15"
|
||||
case WhenDay:
|
||||
h.interval = 3600 * 24
|
||||
h.suffix = "2006-01-02"
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid when_rotate: %d", when)
|
||||
}
|
||||
|
||||
h.interval = h.interval * int64(interval)
|
||||
|
||||
var err error
|
||||
h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fInfo, _ := h.fd.Stat()
|
||||
h.rolloverAt = fInfo.ModTime().Unix() + h.interval
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *TimeRotatingFileHandler) doRollover() {
|
||||
//refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py
|
||||
now := time.Now()
|
||||
|
||||
if h.rolloverAt <= now.Unix() {
|
||||
fName := h.baseName + now.Format(h.suffix)
|
||||
h.fd.Close()
|
||||
e := os.Rename(h.baseName, fName)
|
||||
if e != nil {
|
||||
panic(e)
|
||||
}
|
||||
|
||||
h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
|
||||
h.rolloverAt = time.Now().Unix() + h.interval
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) {
|
||||
h.doRollover()
|
||||
return h.fd.Write(b)
|
||||
}
|
||||
|
||||
func (h *TimeRotatingFileHandler) Close() error {
|
||||
return h.fd.Close()
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
Write(p []byte) (n int, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type StreamHandler struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func NewStreamHandler(w io.Writer) (*StreamHandler, error) {
|
||||
h := new(StreamHandler)
|
||||
|
||||
h.w = w
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *StreamHandler) Write(b []byte) (n int, err error) {
|
||||
return h.w.Write(b)
|
||||
}
|
||||
|
||||
func (h *StreamHandler) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullHandler struct {
|
||||
}
|
||||
|
||||
func NewNullHandler() (*NullHandler, error) {
|
||||
return new(NullHandler), nil
|
||||
}
|
||||
|
||||
func (h *NullHandler) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (h *NullHandler) Close() {
|
||||
|
||||
}
|
|
@ -0,0 +1,226 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
LevelTrace = iota
|
||||
LevelDebug
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
LevelFatal
|
||||
)
|
||||
|
||||
const (
|
||||
Ltime = 1 << iota //time format "2006/01/02 15:04:05"
|
||||
Lfile //file.go:123
|
||||
Llevel //[Trace|Debug|Info...]
|
||||
)
|
||||
|
||||
var LevelName [6]string = [6]string{"Trace", "Debug", "Info", "Warn", "Error", "Fatal"}
|
||||
|
||||
const TimeFormat = "2006/01/02 15:04:05"
|
||||
|
||||
const maxBufPoolSize = 16
|
||||
|
||||
type Logger struct {
|
||||
sync.Mutex
|
||||
|
||||
level int
|
||||
flag int
|
||||
|
||||
handler Handler
|
||||
|
||||
quit chan struct{}
|
||||
msg chan []byte
|
||||
|
||||
bufs [][]byte
|
||||
}
|
||||
|
||||
func New(handler Handler, flag int) *Logger {
|
||||
var l = new(Logger)
|
||||
|
||||
l.level = LevelInfo
|
||||
l.handler = handler
|
||||
|
||||
l.flag = flag
|
||||
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
l.msg = make(chan []byte, 1024)
|
||||
|
||||
l.bufs = make([][]byte, 0, 16)
|
||||
|
||||
go l.run()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func NewDefault(handler Handler) *Logger {
|
||||
return New(handler, Ltime|Lfile|Llevel)
|
||||
}
|
||||
|
||||
func newStdHandler() *StreamHandler {
|
||||
h, _ := NewStreamHandler(os.Stdout)
|
||||
return h
|
||||
}
|
||||
|
||||
var std = NewDefault(newStdHandler())
|
||||
|
||||
func (l *Logger) run() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.msg:
|
||||
l.handler.Write(msg)
|
||||
l.putBuf(msg)
|
||||
case <-l.quit:
|
||||
l.handler.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) popBuf() []byte {
|
||||
l.Lock()
|
||||
var buf []byte
|
||||
if len(l.bufs) == 0 {
|
||||
buf = make([]byte, 0, 1024)
|
||||
} else {
|
||||
buf = l.bufs[len(l.bufs)-1]
|
||||
l.bufs = l.bufs[0 : len(l.bufs)-1]
|
||||
}
|
||||
l.Unlock()
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (l *Logger) putBuf(buf []byte) {
|
||||
l.Lock()
|
||||
if len(l.bufs) < maxBufPoolSize {
|
||||
buf = buf[0:0]
|
||||
l.bufs = append(l.bufs, buf)
|
||||
}
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
if l.quit == nil {
|
||||
return
|
||||
}
|
||||
|
||||
close(l.quit)
|
||||
l.quit = nil
|
||||
}
|
||||
|
||||
func (l *Logger) SetLevel(level int) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
func (l *Logger) Output(callDepth int, level int, format string, v ...interface{}) {
|
||||
if l.level > level {
|
||||
return
|
||||
}
|
||||
|
||||
buf := l.popBuf()
|
||||
|
||||
if l.flag&Ltime > 0 {
|
||||
now := time.Now().Format(TimeFormat)
|
||||
buf = append(buf, '[')
|
||||
buf = append(buf, now...)
|
||||
buf = append(buf, "] "...)
|
||||
}
|
||||
|
||||
if l.flag&Lfile > 0 {
|
||||
_, file, line, ok := runtime.Caller(callDepth)
|
||||
if !ok {
|
||||
file = "???"
|
||||
line = 0
|
||||
} else {
|
||||
for i := len(file) - 1; i > 0; i-- {
|
||||
if file[i] == '/' {
|
||||
file = file[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf = append(buf, file...)
|
||||
buf = append(buf, ':')
|
||||
|
||||
strconv.AppendInt(buf, int64(line), 10)
|
||||
}
|
||||
|
||||
if l.flag&Llevel > 0 {
|
||||
buf = append(buf, '[')
|
||||
buf = append(buf, LevelName[level]...)
|
||||
buf = append(buf, "] "...)
|
||||
}
|
||||
|
||||
s := fmt.Sprintf(format, v...)
|
||||
|
||||
buf = append(buf, s...)
|
||||
|
||||
if s[len(s)-1] != '\n' {
|
||||
buf = append(buf, '\n')
|
||||
}
|
||||
|
||||
l.msg <- buf
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, v ...interface{}) {
|
||||
l.Output(2, LevelTrace, format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, v ...interface{}) {
|
||||
l.Output(2, LevelDebug, format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, v ...interface{}) {
|
||||
l.Output(2, LevelInfo, format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, v ...interface{}) {
|
||||
l.Output(2, LevelWarn, format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, v ...interface{}) {
|
||||
l.Output(2, LevelError, format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(format string, v ...interface{}) {
|
||||
l.Output(2, LevelFatal, format, v...)
|
||||
}
|
||||
|
||||
func SetLevel(level int) {
|
||||
std.SetLevel(level)
|
||||
}
|
||||
|
||||
func Trace(format string, v ...interface{}) {
|
||||
std.Output(2, LevelTrace, format, v...)
|
||||
}
|
||||
|
||||
func Debug(format string, v ...interface{}) {
|
||||
std.Output(2, LevelDebug, format, v...)
|
||||
}
|
||||
|
||||
func Info(format string, v ...interface{}) {
|
||||
std.Output(2, LevelInfo, format, v...)
|
||||
}
|
||||
|
||||
func Warn(format string, v ...interface{}) {
|
||||
std.Output(2, LevelWarn, format, v...)
|
||||
}
|
||||
|
||||
func Error(format string, v ...interface{}) {
|
||||
std.Output(2, LevelError, format, v...)
|
||||
}
|
||||
|
||||
func Fatal(format string, v ...interface{}) {
|
||||
std.Output(2, LevelFatal, format, v...)
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStdStreamLog(t *testing.T) {
|
||||
h, _ := NewStreamHandler(os.Stdout)
|
||||
s := NewDefault(h)
|
||||
s.Info("hello world")
|
||||
|
||||
s.Close()
|
||||
|
||||
Info("hello world")
|
||||
}
|
||||
|
||||
func TestRotatingFileLog(t *testing.T) {
|
||||
path := "./test_log"
|
||||
os.RemoveAll(path)
|
||||
|
||||
os.Mkdir(path, 0777)
|
||||
fileName := path + "/test"
|
||||
|
||||
h, err := NewRotatingFileHandler(fileName, 10, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 10)
|
||||
|
||||
h.Write(buf)
|
||||
|
||||
h.Write(buf)
|
||||
|
||||
if _, err := os.Stat(fileName + ".1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(fileName + ".2"); err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
h.Write(buf)
|
||||
if _, err := os.Stat(fileName + ".2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
h.Close()
|
||||
|
||||
os.RemoveAll(path)
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SocketHandler struct {
|
||||
c net.Conn
|
||||
protocol string
|
||||
addr string
|
||||
}
|
||||
|
||||
func NewSocketHandler(protocol string, addr string) (*SocketHandler, error) {
|
||||
s := new(SocketHandler)
|
||||
|
||||
s.protocol = protocol
|
||||
s.addr = addr
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (h *SocketHandler) Write(p []byte) (n int, err error) {
|
||||
if err = h.connect(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, len(p)+4)
|
||||
|
||||
binary.BigEndian.PutUint32(buf, uint32(len(p)))
|
||||
|
||||
copy(buf[4:], p)
|
||||
|
||||
n, err = h.c.Write(buf)
|
||||
if err != nil {
|
||||
h.c.Close()
|
||||
h.c = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *SocketHandler) Close() error {
|
||||
if h.c != nil {
|
||||
h.c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SocketHandler) connect() error {
|
||||
if h.c != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
h.c, err = net.DialTimeout(h.protocol, h.addr, 20*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/ledis"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
|
|
|
@ -59,7 +59,7 @@ func hdelCommand(c *client) error {
|
|||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if n, err := c.db.HDel(args[0], args[1:]); err != nil {
|
||||
if n, err := c.db.HDel(args[0], args[1:]...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(n)
|
||||
|
@ -138,7 +138,7 @@ func hmgetCommand(c *client) error {
|
|||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if v, err := c.db.HMget(args[0], args[1:]); err != nil {
|
||||
if v, err := c.db.HMget(args[0], args[1:]...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeArray(v)
|
||||
|
@ -207,6 +207,61 @@ func hclearCommand(c *client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func hexpireCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
duration, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.HExpire(args[0], duration); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hexpireAtCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
when, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.HExpireAt(args[0], when); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func httlCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if v, err := c.db.HTTL(args[0]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
register("hdel", hdelCommand)
|
||||
register("hexists", hexistsCommand)
|
||||
|
@ -223,4 +278,7 @@ func init() {
|
|||
//ledisdb special command
|
||||
|
||||
register("hclear", hclearCommand)
|
||||
register("hexpire", hexpireCommand)
|
||||
register("hexpireat", hexpireAtCommand)
|
||||
register("httl", httlCommand)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
|
|
@ -203,6 +203,65 @@ func mgetCommand(c *client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func expireCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
duration, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.Expire(args[0], duration); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func expireAtCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
when, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.ExpireAt(args[0], when); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ttlCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if v, err := c.db.TTL(args[0]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (db *DB) Expire(key []byte, duration int6
|
||||
// func (db *DB) ExpireAt(key []byte, when int64)
|
||||
// func (db *DB) TTL(key []byte) (int64, error)
|
||||
|
||||
func init() {
|
||||
register("decr", decrCommand)
|
||||
register("decrby", decrbyCommand)
|
||||
|
@ -216,4 +275,7 @@ func init() {
|
|||
register("mset", msetCommand)
|
||||
register("set", setCommand)
|
||||
register("setnx", setnxCommand)
|
||||
register("expire", expireCommand)
|
||||
register("expireat", expireAtCommand)
|
||||
register("ttl", ttlCommand)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
|
@ -143,6 +143,61 @@ func lclearCommand(c *client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func lexpireCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
duration, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.LExpire(args[0], duration); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lexpireAtCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
when, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.LExpireAt(args[0], when); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lttlCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if v, err := c.db.LTTL(args[0]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
register("lindex", lindexCommand)
|
||||
register("llen", llenCommand)
|
||||
|
@ -155,5 +210,7 @@ func init() {
|
|||
//ledisdb special command
|
||||
|
||||
register("lclear", lclearCommand)
|
||||
|
||||
register("lexpire", lexpireCommand)
|
||||
register("lexpireat", lexpireAtCommand)
|
||||
register("lttl", lttlCommand)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
|
|
@ -3,14 +3,14 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/siddontang/go-leveldb/leveldb"
|
||||
"github.com/siddontang/ledisdb/leveldb"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func checkDataEqual(master *App, slave *App) error {
|
||||
it := master.ldb.DataDB().Iterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
it := master.ldb.DataDB().RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1)
|
||||
for ; it.Valid(); it.Next() {
|
||||
key := it.Key()
|
||||
value := it.Value()
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func now() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
func TestKVExpire(t *testing.T) {
|
||||
c := getTestConn()
|
||||
defer c.Close()
|
||||
|
||||
k := "a_ttl"
|
||||
c.Do("set", k, "123")
|
||||
|
||||
// expire + ttl
|
||||
exp := int64(10)
|
||||
if n, err := redis.Int(c.Do("expire", k, exp)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if ttl != exp {
|
||||
t.Fatal(ttl)
|
||||
}
|
||||
|
||||
// expireat + ttl
|
||||
tm := now() + 3
|
||||
if n, err := redis.Int(c.Do("expireat", k, tm)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if ttl != 3 {
|
||||
t.Fatal(ttl)
|
||||
}
|
||||
|
||||
kErr := "not_exist_ttl"
|
||||
|
||||
// err - expire, expireat
|
||||
if n, err := redis.Int(c.Do("expire", kErr, tm)); err != nil || n != 0 {
|
||||
t.Fatal(false)
|
||||
}
|
||||
|
||||
if n, err := redis.Int(c.Do("expireat", kErr, tm)); err != nil || n != 0 {
|
||||
t.Fatal(false)
|
||||
}
|
||||
|
||||
if n, err := redis.Int(c.Do("ttl", kErr)); err != nil || n != -1 {
|
||||
t.Fatal(false)
|
||||
}
|
||||
|
||||
}
|
|
@ -421,6 +421,61 @@ func zclearCommand(c *client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func zexpireCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
duration, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.ZExpire(args[0], duration); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func zexpireAtCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
when, err := ledis.StrInt64(args[1], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v, err := c.db.ZExpireAt(args[0], when); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func zttlCommand(c *client) error {
|
||||
args := c.args
|
||||
if len(args) == 0 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if v, err := c.db.ZTTL(args[0]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.writeInteger(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
register("zadd", zaddCommand)
|
||||
register("zcard", zcardCommand)
|
||||
|
@ -438,6 +493,9 @@ func init() {
|
|||
register("zscore", zscoreCommand)
|
||||
|
||||
//ledisdb special command
|
||||
register("zclear", zclearCommand)
|
||||
|
||||
register("zclear", zclearCommand)
|
||||
register("zexpire", zexpireCommand)
|
||||
register("zexpireat", zexpireAtCommand)
|
||||
register("zttl", zttlCommand)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/siddontang/ledisdb/client/go/redis"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/siddontang/go-log/log"
|
||||
"github.com/siddontang/ledisdb/ledis"
|
||||
"github.com/siddontang/ledisdb/log"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
|
|
Loading…
Reference in New Issue