// Copyright 2012 Gary Burd // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain // a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. package redis import ( "bufio" "bytes" "errors" "fmt" "io" "net" "strconv" "sync" "time" ) // conn is the low-level implementation of Conn type conn struct { // Shared mu sync.Mutex pending int err error conn net.Conn // Read readTimeout time.Duration br *bufio.Reader // Write writeTimeout time.Duration bw *bufio.Writer // Scratch space for formatting argument length. // '*' or '$', length, "\r\n" lenScratch [32]byte // Scratch space for formatting integers and floats. numScratch [40]byte } // Dial connects to the Redis server at the given network and address. func Dial(network, address string) (Conn, error) { c, err := net.Dial(network, address) if err != nil { return nil, err } return NewConn(c, 0, 0), nil } // DialTimeout acts like Dial but takes timeouts for establishing the // connection to the server, writing a command and reading a reply. func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { var c net.Conn var err error if connectTimeout > 0 { c, err = net.DialTimeout(network, address, connectTimeout) } else { c, err = net.Dial(network, address) } if err != nil { return nil, err } return NewConn(c, readTimeout, writeTimeout), nil } // NewConn returns a new Redigo connection for the given net connection. func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { return &conn{ conn: netConn, bw: bufio.NewWriter(netConn), br: bufio.NewReader(netConn), readTimeout: readTimeout, writeTimeout: writeTimeout, } } func (c *conn) Close() error { c.mu.Lock() err := c.err if c.err == nil { c.err = errors.New("redigo: closed") err = c.conn.Close() } c.mu.Unlock() return err } func (c *conn) fatal(err error) error { c.mu.Lock() if c.err == nil { c.err = err // Close connection to force errors on subsequent calls and to unblock // other reader or writer. c.conn.Close() } c.mu.Unlock() return err } func (c *conn) Err() error { c.mu.Lock() err := c.err c.mu.Unlock() return err } func (c *conn) writeLen(prefix byte, n int) error { c.lenScratch[len(c.lenScratch)-1] = '\n' c.lenScratch[len(c.lenScratch)-2] = '\r' i := len(c.lenScratch) - 3 for { c.lenScratch[i] = byte('0' + n%10) i -= 1 n = n / 10 if n == 0 { break } } c.lenScratch[i] = prefix _, err := c.bw.Write(c.lenScratch[i:]) return err } func (c *conn) writeString(s string) error { c.writeLen('$', len(s)) c.bw.WriteString(s) _, err := c.bw.WriteString("\r\n") return err } func (c *conn) writeBytes(p []byte) error { c.writeLen('$', len(p)) c.bw.Write(p) _, err := c.bw.WriteString("\r\n") return err } func (c *conn) writeInt64(n int64) error { return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) } func (c *conn) writeFloat64(n float64) error { return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) } func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { c.writeLen('*', 1+len(args)) err = c.writeString(cmd) for _, arg := range args { if err != nil { break } switch arg := arg.(type) { case string: err = c.writeString(arg) case []byte: err = c.writeBytes(arg) case int: err = c.writeInt64(int64(arg)) case int64: err = c.writeInt64(arg) case float64: err = c.writeFloat64(arg) case bool: if arg { err = c.writeString("1") } else { err = c.writeString("0") } case nil: err = c.writeString("") default: var buf bytes.Buffer fmt.Fprint(&buf, arg) err = c.writeBytes(buf.Bytes()) } } return err } func (c *conn) readLine() ([]byte, error) { p, err := c.br.ReadSlice('\n') if err == bufio.ErrBufferFull { return nil, errors.New("redigo: long response line") } if err != nil { return nil, err } i := len(p) - 2 if i < 0 || p[i] != '\r' { return nil, errors.New("redigo: bad response line terminator") } return p[:i], nil } // parseLen parses bulk string and array lengths. func parseLen(p []byte) (int, error) { if len(p) == 0 { return -1, errors.New("redigo: malformed length") } if p[0] == '-' && len(p) == 2 && p[1] == '1' { // handle $-1 and $-1 null replies. return -1, nil } var n int for _, b := range p { n *= 10 if b < '0' || b > '9' { return -1, errors.New("redigo: illegal bytes in length") } n += int(b - '0') } return n, nil } // parseInt parses an integer reply. func parseInt(p []byte) (interface{}, error) { if len(p) == 0 { return 0, errors.New("redigo: malformed integer") } var negate bool if p[0] == '-' { negate = true p = p[1:] if len(p) == 0 { return 0, errors.New("redigo: malformed integer") } } var n int64 for _, b := range p { n *= 10 if b < '0' || b > '9' { return 0, errors.New("redigo: illegal bytes in length") } n += int64(b - '0') } if negate { n = -n } return n, nil } var ( okReply interface{} = "OK" pongReply interface{} = "PONG" ) func (c *conn) readReply() (interface{}, error) { line, err := c.readLine() if err != nil { return nil, err } if len(line) == 0 { return nil, errors.New("redigo: short response line") } switch line[0] { case '+': switch { case len(line) == 3 && line[1] == 'O' && line[2] == 'K': // Avoid allocation for frequent "+OK" response. return okReply, nil case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': // Avoid allocation in PING command benchmarks :) return pongReply, nil default: return string(line[1:]), nil } case '-': return Error(string(line[1:])), nil case ':': return parseInt(line[1:]) case '$': n, err := parseLen(line[1:]) if n < 0 || err != nil { return nil, err } p := make([]byte, n) _, err = io.ReadFull(c.br, p) if err != nil { return nil, err } if line, err := c.readLine(); err != nil { return nil, err } else if len(line) != 0 { return nil, errors.New("redigo: bad bulk string format") } return p, nil case '*': n, err := parseLen(line[1:]) if n < 0 || err != nil { return nil, err } r := make([]interface{}, n) for i := range r { r[i], err = c.readReply() if err != nil { return nil, err } } return r, nil } return nil, errors.New("redigo: unexpected response line") } func (c *conn) Send(cmd string, args ...interface{}) error { c.mu.Lock() c.pending += 1 c.mu.Unlock() if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if err := c.writeCommand(cmd, args); err != nil { return c.fatal(err) } return nil } func (c *conn) Flush() error { if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if err := c.bw.Flush(); err != nil { return c.fatal(err) } return nil } func (c *conn) Receive() (reply interface{}, err error) { c.mu.Lock() // There can be more receives than sends when using pub/sub. To allow // normal use of the connection after unsubscribe from all channels, do not // decrement pending to a negative value. if c.pending > 0 { c.pending -= 1 } c.mu.Unlock() if c.readTimeout != 0 { c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) } if reply, err = c.readReply(); err != nil { return nil, c.fatal(err) } if err, ok := reply.(Error); ok { return nil, err } return } func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { c.mu.Lock() pending := c.pending c.pending = 0 c.mu.Unlock() if cmd == "" && pending == 0 { return nil, nil } if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if cmd != "" { c.writeCommand(cmd, args) } if err := c.bw.Flush(); err != nil { return nil, c.fatal(err) } if c.readTimeout != 0 { c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) } if cmd == "" { reply := make([]interface{}, pending) for i := range reply { r, e := c.readReply() if e != nil { return nil, c.fatal(e) } reply[i] = r } return reply, nil } var err error var reply interface{} for i := 0; i <= pending; i++ { var e error if reply, e = c.readReply(); e != nil { return nil, c.fatal(e) } if e, ok := reply.(Error); ok && err == nil { err = e } } return reply, err }