ledisdb/client/go/ledis/conn.go

413 lines
7.4 KiB
Go

package ledis
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
)
// Error is an error reply sent by the server (a response line starting with
// '-'); its value is the text that followed the '-' prefix.
type Error string

// Error returns the server's error text unchanged so that Error satisfies the
// built-in error interface.
func (e Error) Error() string {
	msg := string(e)
	return msg
}
// Conn is a single connection to a ledis server at addr. The network
// connection is dialed lazily by Send, and replies are read back with
// Receive or ReceiveBulkTo.
type Conn struct {
	// cm guards the connection state (c, closed) in connect and finalize;
	// wm is held by Send; rm is held by Receive and ReceiveBulkTo.
	cm, wm, rm sync.Mutex

	// closed is set by finalize once the network connection has been closed.
	closed bool

	// client, when non-nil, takes the Conn back (via put) on Close instead of
	// the network connection being closed.
	client *Client

	addr string // server address; the dial protocol comes from getProto(addr)
	c    net.Conn
	br   *bufio.Reader
	bw   *bufio.Writer

	// Buffer sizes used when br and bw are first created.
	rSize, wSize int

	// Scratch space for formatting argument length.
	// '*' or '$', length, "\r\n"
	lenScratch [32]byte
	// Scratch space for formatting integers and floats.
	numScratch [40]byte
}
// NewConn returns a Conn for addr with 4096-byte read and write buffers.
// No network connection is made here; Send dials on first use.
func NewConn(addr string) *Conn {
	return &Conn{
		addr:  addr,
		rSize: 4096,
		wSize: 4096,
	}
}
// NewConnSize is like NewConn but uses readSize and writeSize as the sizes of
// the buffered reader and writer created on the first connect.
func NewConnSize(addr string, readSize int, writeSize int) *Conn {
	conn := NewConn(addr)
	conn.rSize, conn.wSize = readSize, writeSize
	return conn
}
// Close releases c. A Conn owned by a Client is handed back to that client
// via put; a standalone Conn has its network connection closed.
func (c *Conn) Close() {
	if c.client == nil {
		c.finalize()
		return
	}
	c.client.put(c)
}
// Do sends cmd with args and waits for the matching reply; it is Send
// followed by Receive. A server error reply is returned as both the reply
// and the error.
func (c *Conn) Do(cmd string, args ...interface{}) (interface{}, error) {
	err := c.Send(cmd, args...)
	if err != nil {
		return nil, err
	}
	return c.Receive()
}
// Send dials if necessary, then encodes cmd and args as one request and
// flushes it to the server. It does not read the reply; pair it with Receive.
// Any encoding or flush error closes the connection.
func (c *Conn) Send(cmd string, args ...interface{}) error {
	if err := c.connect(); err != nil {
		return err
	}

	c.wm.Lock()
	defer c.wm.Unlock()

	err := c.writeCommand(cmd, args)
	if err == nil {
		err = c.bw.Flush()
	}
	if err != nil {
		c.finalize()
		return err
	}
	return nil
}
// Receive reads the next reply from the connection. A server error reply is
// returned as the reply value and also, as type Error, as the error; the
// connection stays open in that case. Any other read failure closes the
// connection via finalize.
// NOTE(review): c.br is only set by connect, so calling Receive before any
// Send dereferences a nil reader.
func (c *Conn) Receive() (interface{}, error) {
	c.rm.Lock()
	defer c.rm.Unlock()

	reply, err := c.readReply()
	if err != nil {
		c.finalize()
		return nil, err
	}
	if serverErr, isServerErr := reply.(Error); isServerErr {
		return reply, serverErr
	}
	return reply, nil
}
// ReceiveBulkTo reads the next reply, which must be a bulk string, and copies
// its payload to w instead of buffering it in memory. A server error reply is
// returned as an Error and leaves the connection open; any other failure
// closes the connection.
func (c *Conn) ReceiveBulkTo(w io.Writer) error {
	c.rm.Lock()
	defer c.rm.Unlock()

	err := c.readBulkReplyTo(w)
	if err == nil {
		return nil
	}
	if _, isServerErr := err.(Error); !isServerErr {
		c.finalize()
	}
	return err
}
// finalize closes the underlying network connection, if one exists, and marks
// c as closed so that the next connect dials again. Calling it again before a
// reconnect is a no-op.
func (c *Conn) finalize() {
	c.cm.Lock()
	defer c.cm.Unlock()

	if c.closed {
		return
	}
	// c.c is nil when connect was never called or the last dial failed;
	// calling Close through the nil interface would panic.
	if c.c != nil {
		// Best effort: finalize runs on paths that are already reporting
		// another error, or from Close, which has no error to return.
		_ = c.c.Close()
	}
	c.closed = true
}
// connect makes sure c has an open network connection. It returns immediately
// when one is already open; otherwise it dials addr (protocol from getProto)
// and points the buffered reader and writer at the new connection, reusing
// them via Reset when they already exist. A failed dial leaves c.c nil.
func (c *Conn) connect() error {
	c.cm.Lock()
	defer c.cm.Unlock()

	if !c.closed && c.c != nil {
		return nil
	}

	conn, err := net.Dial(getProto(c.addr), c.addr)
	if err != nil {
		c.c = nil
		return err
	}
	c.c = conn
	// Mark the connection open again. Without this, once finalize has set
	// closed, every later call would dial yet another socket and leak the
	// previous one (finalize skips Close while closed is true).
	c.closed = false

	if c.br != nil {
		c.br.Reset(conn)
	} else {
		c.br = bufio.NewReaderSize(conn, c.rSize)
	}
	if c.bw != nil {
		c.bw.Reset(conn)
	} else {
		c.bw = bufio.NewWriterSize(conn, c.wSize)
	}
	return nil
}
// writeLen buffers a header of the form <prefix><n>\r\n (for example "*3\r\n"
// or "$5\r\n"), formatting n into lenScratch so no allocation is needed.
// Visible callers always pass a length, so n is presumably non-negative.
func (c *Conn) writeLen(prefix byte, n int) error {
	buf := c.lenScratch[:]
	pos := len(buf) - 2
	buf[pos], buf[pos+1] = '\r', '\n'
	// Emit digits right to left; the body runs at least once so n == 0 gives "0".
	for {
		pos--
		buf[pos] = '0' + byte(n%10)
		n /= 10
		if n == 0 {
			break
		}
	}
	pos--
	buf[pos] = prefix
	_, err := c.bw.Write(buf[pos:])
	return err
}
// writeString buffers s as a bulk string: "$<len>\r\n<s>\r\n".
// The first write error encountered is returned.
func (c *Conn) writeString(s string) error {
	if err := c.writeLen('$', len(s)); err != nil {
		return err
	}
	if _, err := c.bw.WriteString(s); err != nil {
		return err
	}
	_, err := c.bw.WriteString("\r\n")
	return err
}
// writeBytes buffers p as a bulk string: "$<len>\r\n<p>\r\n".
// The first write error encountered is returned.
func (c *Conn) writeBytes(p []byte) error {
	if err := c.writeLen('$', len(p)); err != nil {
		return err
	}
	if _, err := c.bw.Write(p); err != nil {
		return err
	}
	_, err := c.bw.WriteString("\r\n")
	return err
}
// writeInt64 buffers n in base 10 as a bulk string, formatting it into
// numScratch to avoid allocating.
func (c *Conn) writeInt64(n int64) error {
	digits := strconv.AppendInt(c.numScratch[:0], n, 10)
	return c.writeBytes(digits)
}
// writeFloat64 buffers n as a bulk string using the shortest 'g' formatting
// that round-trips, written into numScratch to avoid allocating.
func (c *Conn) writeFloat64(n float64) error {
	text := strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)
	return c.writeBytes(text)
}
// writeCommand buffers cmd and args as one request array: a "*<count>" header
// followed by one bulk string per element. Arguments are encoded by type:
// string and []byte verbatim, int/int64/float64 as their textual form, bool as
// "1" or "0", nil as the empty string, and any other value via fmt.Fprint.
// Nothing is flushed here. The first write error stops encoding and is returned.
func (c *Conn) writeCommand(cmd string, args []interface{}) error {
	if err := c.writeLen('*', 1+len(args)); err != nil {
		return err
	}
	if err := c.writeString(cmd); err != nil {
		return err
	}
	for _, arg := range args {
		var err error
		switch arg := arg.(type) {
		case string:
			err = c.writeString(arg)
		case []byte:
			err = c.writeBytes(arg)
		case int:
			err = c.writeInt64(int64(arg))
		case int64:
			err = c.writeInt64(arg)
		case float64:
			err = c.writeFloat64(arg)
		case bool:
			if arg {
				err = c.writeString("1")
			} else {
				err = c.writeString("0")
			}
		case nil:
			err = c.writeString("")
		default:
			var buf bytes.Buffer
			// Writes into a bytes.Buffer cannot fail.
			fmt.Fprint(&buf, arg)
			err = c.writeBytes(buf.Bytes())
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// readLine reads one "\r\n"-terminated line and returns it without the
// terminator. The slice aliases the bufio.Reader's internal buffer (see
// bufio.Reader.ReadSlice), so it is only valid until the next read. A line
// longer than the read buffer is reported as an error.
func (c *Conn) readLine() ([]byte, error) {
	line, err := c.br.ReadSlice('\n')
	switch {
	case err == bufio.ErrBufferFull:
		return nil, errors.New("ledis: long response line")
	case err != nil:
		return nil, err
	}
	// On success ReadSlice's result ends in '\n'; the byte before it must be '\r'.
	end := len(line) - 2
	if end >= 0 && line[end] == '\r' {
		return line[:end], nil
	}
	return nil, errors.New("ledis: bad response line terminator")
}
// parseLen parses the length field of a bulk string ("$") or array ("*")
// header, with the type byte already stripped. It returns -1 and a nil error
// for the null marker "-1". It returns -1 and a non-nil error for an empty,
// non-numeric, or out-of-range length.
func parseLen(p []byte) (int, error) {
	if len(p) == 0 {
		return -1, errors.New("ledis: malformed length")
	}
	if len(p) == 2 && p[0] == '-' && p[1] == '1' {
		// handle $-1 and *-1 null replies.
		return -1, nil
	}
	const maxInt = int(^uint(0) >> 1)
	var n int
	for _, b := range p {
		if b < '0' || b > '9' {
			return -1, errors.New("ledis: illegal bytes in length")
		}
		d := int(b - '0')
		// Reject lengths that would overflow int. Letting them wrap negative
		// makes callers treat the reply as null and desynchronizes the stream.
		if n > (maxInt-d)/10 {
			return -1, errors.New("ledis: length out of range")
		}
		n = n*10 + d
	}
	return n, nil
}
// parseInt parses the payload of an integer (":") reply, with the type byte
// already stripped, into an int64 boxed in an interface. An optional leading
// '-' is accepted. Empty input, non-digit bytes, and values outside the int64
// range are reported as errors.
func parseInt(p []byte) (interface{}, error) {
	if len(p) == 0 {
		return 0, errors.New("ledis: malformed integer")
	}
	negate := p[0] == '-'
	if negate {
		p = p[1:]
		if len(p) == 0 {
			return 0, errors.New("ledis: malformed integer")
		}
	}
	// Accumulate the magnitude as uint64 so that the most negative int64,
	// whose magnitude is one larger than the largest positive int64, fits.
	limit := uint64(1<<63 - 1)
	if negate {
		limit++
	}
	var n uint64
	for _, b := range p {
		if b < '0' || b > '9' {
			return 0, errors.New("ledis: illegal bytes in integer")
		}
		d := uint64(b - '0')
		if n > (limit-d)/10 {
			return 0, errors.New("ledis: integer out of range")
		}
		n = n*10 + d
	}
	if negate {
		// For n == 1<<63 the conversion yields the minimum int64, and negating
		// it leaves it unchanged, which is the intended result.
		return -int64(n), nil
	}
	return int64(n), nil
}
// Pre-boxed status replies that readReply returns for "+OK" and "+PONG", so
// these frequent replies do not allocate a new string each time.
var okReply, pongReply interface{} = "OK", "PONG"
// readBulkReplyTo reads one reply and streams a bulk string payload to w.
//
// The reply kind determines the result:
//   - "-" (server error): returned as an Error.
//   - "$-1" (null bulk): nothing is written and nil is returned.
//   - Any other kind: an error is returned.
//
// Errors from w or from reading the payload are returned as-is.
func (c *Conn) readBulkReplyTo(w io.Writer) error {
	line, err := c.readLine()
	if err != nil {
		return err
	}
	if len(line) == 0 {
		return errors.New("ledis: short response line")
	}

	switch line[0] {
	case '-':
		return Error(string(line[1:]))
	case '$':
		n, err := parseLen(line[1:])
		if n < 0 || err != nil {
			return err
		}
		// io.CopyN reports a non-nil error whenever fewer than n bytes were
		// copied, so no separate short-copy check is needed.
		if _, err := io.CopyN(w, c.br, int64(n)); err != nil {
			return err
		}
		trailer, err := c.readLine()
		if err != nil {
			return err
		}
		if len(trailer) != 0 {
			return errors.New("ledis: bad bulk string format")
		}
		return nil
	default:
		return fmt.Errorf("ledis: expected bulk string reply, got type %c", line[0])
	}
}
// readReply reads one complete reply and returns it as a Go value:
//   - status ("+") as string, with "OK" and "PONG" as shared values;
//   - server error ("-") as Error, with a nil error;
//   - integer (":") as int64;
//   - bulk ("$") as []byte;
//   - array ("*") as []interface{}, with elements read recursively.
//
// Null bulk and null array replies yield (nil, nil).
func (c *Conn) readReply() (interface{}, error) {
	line, err := c.readLine()
	switch {
	case err != nil:
		return nil, err
	case len(line) == 0:
		return nil, errors.New("ledis: short response line")
	}

	kind, payload := line[0], line[1:]
	switch kind {
	case '+':
		// The comparison against string(payload) does not allocate, and the
		// common replies reuse shared values.
		switch string(payload) {
		case "OK":
			return okReply, nil
		case "PONG":
			return pongReply, nil
		}
		return string(payload), nil
	case '-':
		return Error(string(payload)), nil
	case ':':
		return parseInt(payload)
	case '$':
		size, err := parseLen(payload)
		if err != nil || size < 0 {
			return nil, err
		}
		body := make([]byte, size)
		if _, err := io.ReadFull(c.br, body); err != nil {
			return nil, err
		}
		trailer, err := c.readLine()
		if err != nil {
			return nil, err
		}
		if len(trailer) != 0 {
			return nil, errors.New("ledis: bad bulk string format")
		}
		return body, nil
	case '*':
		count, err := parseLen(payload)
		if err != nil || count < 0 {
			return nil, err
		}
		items := make([]interface{}, count)
		for i := 0; i < count; i++ {
			item, err := c.readReply()
			if err != nil {
				return nil, err
			}
			items[i] = item
		}
		return items, nil
	}
	return nil, errors.New("ledis: unexpected response line")
}
// newConn creates a Conn for addr using the client's configured read and
// write buffer sizes, and links it to c so that Conn.Close hands it back
// through c.put.
func (c *Client) newConn(addr string) *Conn {
	readSize, writeSize := c.cfg.ReadBufferSize, c.cfg.WriteBufferSize
	conn := NewConnSize(addr, readSize, writeSize)
	conn.client = c
	return conn
}