redis/internal/proto/writer.go

221 lines
4.1 KiB
Go

package proto
import (
"encoding"
"fmt"
"io"
"net"
"reflect"
"strconv"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
// writer is the output sink a Writer wraps. It must accept byte slices,
// single bytes and strings; the three embedded interfaces together give it
// the Write, WriteByte and WriteString methods.
type writer interface {
	io.Writer
	io.ByteWriter
	io.StringWriter
}
// Writer wraps an underlying writer and carries reusable scratch buffers so
// that formatting lengths and numbers does not need a fresh slice per call.
type Writer struct {
	// The embedded sink; its methods are promoted onto Writer.
	writer

	// lenBuf holds the most recently formatted length prefix, numBuf the
	// most recently formatted number or timestamp. Both are overwritten
	// from index 0 on every use.
	lenBuf, numBuf []byte
}
// NewWriter returns a Writer that sends its output to wr. Each of the two
// scratch buffers is allocated separately with 64 bytes.
func NewWriter(wr writer) *Writer {
	const scratchSize = 64

	w := new(Writer)
	w.writer = wr
	w.lenBuf = make([]byte, scratchSize)
	w.numBuf = make([]byte, scratchSize)
	return w
}
// WriteArgs writes a whole argument list: the RespArray marker byte, the
// number of arguments, and then every argument in order through WriteArg.
// It stops at, and returns, the first error encountered.
func (w *Writer) WriteArgs(args []interface{}) error {
	err := w.WriteByte(RespArray)
	if err == nil {
		err = w.writeLen(len(args))
	}
	if err != nil {
		return err
	}

	for i := range args {
		if err = w.WriteArg(args[i]); err != nil {
			return err
		}
	}
	return nil
}
// writeLen writes n in base 10 followed by CRLF. The text is built in the
// lenBuf scratch slice, which afterwards holds exactly what was written.
func (w *Writer) writeLen(n int) error {
	buf := strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
	buf = append(buf, "\r\n"...)
	w.lenBuf = buf

	_, err := w.Write(buf)
	return err
}
// WriteArg writes a single argument as one string-typed value.
//
// Handled directly:
//   - nil is written as the empty string;
//   - string and []byte are written as-is;
//   - all built-in integer and float types are written in decimal;
//   - bool is written as the integer 1 or 0;
//   - pointers to string, integer, float and bool types are dereferenced
//     and written like the value they point to;
//   - time.Time is formatted with time.RFC3339Nano;
//   - time.Duration is written as its nanosecond count;
//   - encoding.BinaryMarshaler values are written as the marshalled bytes;
//   - net.IP is written as its raw bytes.
//
// Any other type is passed on to writeArgExtra.
//
// A nil pointer to one of the scalar types is reported as an error rather
// than being dereferenced, so a typed nil argument can no longer panic here.
// Errors from MarshalBinary and from the underlying writer are returned
// unchanged.
func (w *Writer) WriteArg(v interface{}) error {
	switch v := v.(type) {
	case nil:
		return w.string("")
	case string:
		return w.string(v)
	case *string:
		if v != nil {
			return w.string(*v)
		}
	case []byte:
		return w.bytes(v)
	case int:
		return w.int(int64(v))
	case *int:
		if v != nil {
			return w.int(int64(*v))
		}
	case int8:
		return w.int(int64(v))
	case *int8:
		if v != nil {
			return w.int(int64(*v))
		}
	case int16:
		return w.int(int64(v))
	case *int16:
		if v != nil {
			return w.int(int64(*v))
		}
	case int32:
		return w.int(int64(v))
	case *int32:
		if v != nil {
			return w.int(int64(*v))
		}
	case int64:
		return w.int(v)
	case *int64:
		if v != nil {
			return w.int(*v)
		}
	case uint:
		return w.uint(uint64(v))
	case *uint:
		if v != nil {
			return w.uint(uint64(*v))
		}
	case uint8:
		return w.uint(uint64(v))
	case *uint8:
		if v != nil {
			return w.uint(uint64(*v))
		}
	case uint16:
		return w.uint(uint64(v))
	case *uint16:
		if v != nil {
			return w.uint(uint64(*v))
		}
	case uint32:
		return w.uint(uint64(v))
	case *uint32:
		if v != nil {
			return w.uint(uint64(*v))
		}
	case uint64:
		return w.uint(v)
	case *uint64:
		if v != nil {
			return w.uint(*v)
		}
	case float32:
		return w.float(float64(v))
	case *float32:
		if v != nil {
			return w.float(float64(*v))
		}
	case float64:
		return w.float(v)
	case *float64:
		if v != nil {
			return w.float(*v)
		}
	case bool:
		if v {
			return w.int(1)
		}
		return w.int(0)
	case *bool:
		if v != nil {
			if *v {
				return w.int(1)
			}
			return w.int(0)
		}
	case time.Time:
		// Format into the shared scratch buffer to avoid a per-call allocation.
		w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
		return w.bytes(w.numBuf)
	case time.Duration:
		return w.int(v.Nanoseconds())
	case encoding.BinaryMarshaler:
		b, err := v.MarshalBinary()
		if err != nil {
			return err
		}
		return w.bytes(b)
	case net.IP:
		return w.bytes(v)
	default:
		return w.writeArgExtra(v)
	}

	// Every non-pointer case returns inside the switch, so control only gets
	// here when one of the pointer cases matched a nil pointer.
	return fmt.Errorf("redis: can't marshal nil %T", v)
}
// writeArgExtra writes v according to its reflected kind: bool as 1 or 0,
// signed and unsigned integers (including uintptr) and floats as numbers,
// strings as strings, and slices whose element kind is uint8 as bytes.
// Any other kind produces an error that names the offending type.
func (w *Writer) writeArgExtra(v interface{}) error {
	rv := reflect.ValueOf(v)

	switch {
	case rv.Kind() == reflect.Bool:
		var n int64
		if rv.Bool() {
			n = 1
		}
		return w.int(n)
	case rv.CanInt():
		return w.int(rv.Int())
	case rv.CanUint():
		return w.uint(rv.Uint())
	case rv.CanFloat():
		return w.float(rv.Float())
	case rv.Kind() == reflect.String:
		return w.string(rv.String())
	case rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Uint8:
		return w.bytes(rv.Bytes())
	}

	return fmt.Errorf(
		"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
}
// bytes writes b as one length-prefixed value: the RespString marker byte,
// len(b) via writeLen, the payload itself, and a closing CRLF. The first
// failing step aborts the sequence and its error is returned.
func (w *Writer) bytes(b []byte) error {
	err := w.WriteByte(RespString)
	if err == nil {
		err = w.writeLen(len(b))
	}
	if err == nil {
		_, err = w.Write(b)
	}
	if err != nil {
		return err
	}
	return w.crlf()
}
// string writes s the same way bytes writes a byte slice, after converting
// it with util.StringToBytes.
// NOTE(review): the conversion presumably avoids a copy — confirm in util.
func (w *Writer) string(s string) error {
	raw := util.StringToBytes(s)
	return w.bytes(raw)
}
// uint writes n as its base-10 text. The digits are produced in the numBuf
// scratch slice, which is left holding them.
func (w *Writer) uint(n uint64) error {
	digits := strconv.AppendUint(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}
// int writes n as its base-10 text, with a leading minus sign for negative
// values. The digits are produced in the numBuf scratch slice, which is
// left holding them.
func (w *Writer) int(n int64) error {
	digits := strconv.AppendInt(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}
// float writes f in plain decimal notation (no exponent), using the fewest
// digits that still round-trip as a float64. The text is produced in the
// numBuf scratch slice, which is left holding it.
func (w *Writer) float(f float64) error {
	text := strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
	w.numBuf = text
	return w.bytes(text)
}
// crlf writes the two-byte line terminator, carriage return then line feed,
// and stops at the first byte that fails to write.
func (w *Writer) crlf() error {
	for _, c := range [...]byte{'\r', '\n'} {
		if err := w.WriteByte(c); err != nil {
			return err
		}
	}
	return nil
}