redis/internal/proto/writer.go

196 lines
3.5 KiB
Go
Raw Normal View History

2018-08-17 13:56:37 +03:00
package proto
import (
"encoding"
"fmt"
"io"
"net"
2018-08-17 13:56:37 +03:00
"strconv"
"time"
2018-08-17 13:56:37 +03:00
2023-01-23 09:48:54 +03:00
"github.com/redis/go-redis/v9/internal/util"
2018-08-17 13:56:37 +03:00
)
2020-06-09 16:29:53 +03:00
// writer is the minimal sink a Writer encodes onto: bulk byte writes,
// string writes, and single-byte writes. Types such as *bufio.Writer and
// *bytes.Buffer provide all three.
type writer interface {
	io.Writer
	io.StringWriter
	io.ByteWriter
}
2018-08-17 13:56:37 +03:00
type Writer struct {
2020-06-09 16:29:53 +03:00
writer
2018-08-17 13:56:37 +03:00
lenBuf []byte
numBuf []byte
}
2020-06-09 16:29:53 +03:00
func NewWriter(wr writer) *Writer {
2018-08-17 13:56:37 +03:00
return &Writer{
2020-06-09 16:29:53 +03:00
writer: wr,
2018-08-17 13:56:37 +03:00
lenBuf: make([]byte, 64),
numBuf: make([]byte, 64),
}
}
func (w *Writer) WriteArgs(args []interface{}) error {
if err := w.WriteByte(RespArray); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
2020-06-09 16:29:53 +03:00
if err := w.writeLen(len(args)); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
for _, arg := range args {
2020-06-09 16:29:53 +03:00
if err := w.WriteArg(arg); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
}
return nil
}
func (w *Writer) writeLen(n int) error {
w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
w.lenBuf = append(w.lenBuf, '\r', '\n')
2020-06-09 16:29:53 +03:00
_, err := w.Write(w.lenBuf)
2018-08-17 13:56:37 +03:00
return err
}
2020-06-09 16:29:53 +03:00
func (w *Writer) WriteArg(v interface{}) error {
2018-08-17 13:56:37 +03:00
switch v := v.(type) {
case nil:
return w.string("")
case string:
return w.string(v)
case *string:
return w.string(*v)
2018-08-17 13:56:37 +03:00
case []byte:
return w.bytes(v)
case int:
return w.int(int64(v))
case *int:
return w.int(int64(*v))
2018-08-17 13:56:37 +03:00
case int8:
return w.int(int64(v))
case *int8:
return w.int(int64(*v))
2018-08-17 13:56:37 +03:00
case int16:
return w.int(int64(v))
case *int16:
return w.int(int64(*v))
2018-08-17 13:56:37 +03:00
case int32:
return w.int(int64(v))
case *int32:
return w.int(int64(*v))
2018-08-17 13:56:37 +03:00
case int64:
return w.int(v)
case *int64:
return w.int(*v)
2018-08-17 13:56:37 +03:00
case uint:
return w.uint(uint64(v))
case *uint:
return w.uint(uint64(*v))
2018-08-17 13:56:37 +03:00
case uint8:
return w.uint(uint64(v))
case *uint8:
return w.uint(uint64(*v))
2018-08-17 13:56:37 +03:00
case uint16:
return w.uint(uint64(v))
case *uint16:
return w.uint(uint64(*v))
2018-08-17 13:56:37 +03:00
case uint32:
return w.uint(uint64(v))
case *uint32:
return w.uint(uint64(*v))
2018-08-17 13:56:37 +03:00
case uint64:
return w.uint(v)
case *uint64:
return w.uint(*v)
2018-08-17 13:56:37 +03:00
case float32:
return w.float(float64(v))
case *float32:
return w.float(float64(*v))
2018-08-17 13:56:37 +03:00
case float64:
return w.float(v)
case *float64:
return w.float(*v)
2018-08-17 13:56:37 +03:00
case bool:
if v {
return w.int(1)
}
2019-07-25 13:53:00 +03:00
return w.int(0)
case *bool:
if *v {
return w.int(1)
}
return w.int(0)
case time.Time:
2020-05-09 17:30:16 +03:00
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case time.Duration:
return w.int(v.Nanoseconds())
2018-08-17 13:56:37 +03:00
case encoding.BinaryMarshaler:
b, err := v.MarshalBinary()
if err != nil {
return err
}
return w.bytes(b)
2024-02-23 09:51:20 +03:00
case encoding.TextMarshaler:
b, err := v.MarshalText()
if err != nil {
return err
}
return w.bytes(b)
case net.IP:
return w.bytes(v)
2018-08-17 13:56:37 +03:00
default:
return fmt.Errorf(
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
}
}
func (w *Writer) bytes(b []byte) error {
if err := w.WriteByte(RespString); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
2020-06-09 16:29:53 +03:00
if err := w.writeLen(len(b)); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
2020-06-09 16:29:53 +03:00
if _, err := w.Write(b); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
return w.crlf()
}
// string writes s as a bulk-string frame. The conversion goes through
// util.StringToBytes, presumably a zero-copy view of s (verify in the util
// package), so the resulting slice is only read, never modified.
func (w *Writer) string(s string) error {
	payload := util.StringToBytes(s)
	return w.bytes(payload)
}
// uint writes n as a bulk-string frame of its base-10 digits. The digits
// are formatted into the reusable numBuf scratch buffer.
func (w *Writer) uint(n uint64) error {
	digits := strconv.AppendUint(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}
// int writes n as a bulk-string frame of its base-10 digits (with a leading
// '-' when negative). The text is formatted into the reusable numBuf.
func (w *Writer) int(n int64) error {
	digits := strconv.AppendInt(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}
// float writes f as a bulk-string frame in plain decimal notation ('f',
// never an exponent), using the fewest digits that round-trip exactly
// (precision -1). The text is formatted into the reusable numBuf.
func (w *Writer) float(f float64) error {
	text := strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
	w.numBuf = text
	return w.bytes(text)
}
func (w *Writer) crlf() error {
2020-06-09 16:29:53 +03:00
if err := w.WriteByte('\r'); err != nil {
2018-08-17 13:56:37 +03:00
return err
}
2020-06-09 16:29:53 +03:00
return w.WriteByte('\n')
2018-08-17 13:56:37 +03:00
}