Better AOF/AOFMD5 tests

This commit is contained in:
tidwall 2022-09-26 15:43:14 -07:00
parent c093b041e1
commit ad8d40dee5
5 changed files with 218 additions and 128 deletions

View File

@ -399,73 +399,80 @@ func (s liveAOFSwitches) Error() string {
return goingLive
}
func (s *Server) cmdAOFMD5(msg *Message) (res resp.Value, err error) {
// AOFMD5 pos size
func (s *Server) cmdAOFMD5(msg *Message) (resp.Value, error) {
start := time.Now()
vs := msg.Args[1:]
var ok bool
var spos, ssize string
if vs, spos, ok = tokenval(vs); !ok || spos == "" {
return NOMessage, errInvalidNumberOfArguments
// >> Args
args := msg.Args
if len(args) != 3 {
return retrerr(errInvalidNumberOfArguments)
}
if vs, ssize, ok = tokenval(vs); !ok || ssize == "" {
return NOMessage, errInvalidNumberOfArguments
}
if len(vs) != 0 {
return NOMessage, errInvalidNumberOfArguments
}
pos, err := strconv.ParseInt(spos, 10, 64)
pos, err := strconv.ParseInt(args[1], 10, 64)
if err != nil || pos < 0 {
return NOMessage, errInvalidArgument(spos)
return retrerr(errInvalidArgument(args[1]))
}
size, err := strconv.ParseInt(ssize, 10, 64)
size, err := strconv.ParseInt(args[2], 10, 64)
if err != nil || size < 0 {
return NOMessage, errInvalidArgument(ssize)
return retrerr(errInvalidArgument(args[2]))
}
// >> Operation
sum, err := s.checksum(pos, size)
if err != nil {
return NOMessage, err
return retrerr(err)
}
switch msg.OutputType {
case JSON:
res = resp.StringValue(
fmt.Sprintf(`{"ok":true,"md5":"%s","elapsed":"%s"}`, sum, time.Since(start)))
case RESP:
res = resp.SimpleStringValue(sum)
// >> Response
if msg.OutputType == JSON {
return resp.StringValue(fmt.Sprintf(
`{"ok":true,"md5":"%s","elapsed":"%s"}`,
sum, time.Since(start))), nil
}
return res, nil
return resp.SimpleStringValue(sum), nil
}
func (s *Server) cmdAOF(msg *Message) (res resp.Value, err error) {
// AOF pos
func (s *Server) cmdAOF(msg *Message) (resp.Value, error) {
if s.aof == nil {
return NOMessage, errors.New("aof disabled")
return retrerr(errors.New("aof disabled"))
}
vs := msg.Args[1:]
var ok bool
var spos string
if vs, spos, ok = tokenval(vs); !ok || spos == "" {
return NOMessage, errInvalidNumberOfArguments
// >> Args
args := msg.Args
if len(args) != 2 {
return retrerr(errInvalidNumberOfArguments)
}
if len(vs) != 0 {
return NOMessage, errInvalidNumberOfArguments
}
pos, err := strconv.ParseInt(spos, 10, 64)
pos, err := strconv.ParseInt(args[1], 10, 64)
if err != nil || pos < 0 {
return NOMessage, errInvalidArgument(spos)
return retrerr(errInvalidArgument(args[1]))
}
// >> Operation
f, err := os.Open(s.aof.Name())
if err != nil {
return NOMessage, err
return retrerr(err)
}
defer f.Close()
n, err := f.Seek(0, 2)
if err != nil {
return NOMessage, err
return retrerr(err)
}
if n < pos {
return NOMessage, errors.New("pos is too big, must be less that the aof_size of leader")
return retrerr(errors.New(
"pos is too big, must be less that the aof_size of leader"))
}
// >> Response
var ls liveAOFSwitches
ls.pos = pos
return NOMessage, ls
@ -478,8 +485,6 @@ func (s *Server) liveAOF(pos int64, conn net.Conn, rd *PipelineReader, msg *Mess
if err != nil {
return err
}
defer f.Close()
s.mu.Lock()
s.aofconnM[conn] = f
s.mu.Unlock()
@ -488,58 +493,31 @@ func (s *Server) liveAOF(pos int64, conn net.Conn, rd *PipelineReader, msg *Mess
delete(s.aofconnM, conn)
s.mu.Unlock()
conn.Close()
f.Close()
}()
if _, err := conn.Write([]byte("+OK\r\n")); err != nil {
return err
}
if _, err := f.Seek(pos, 0); err != nil {
return err
}
cond := sync.NewCond(&sync.Mutex{})
var mustQuit bool
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer func() {
cond.L.Lock()
mustQuit = true
cond.Broadcast()
cond.L.Unlock()
f.Close()
conn.Close()
wg.Done()
}()
for {
vs, err := rd.ReadMessages()
if err != nil {
if err != io.EOF {
log.Error(err)
}
return
}
for _, v := range vs {
switch v.Command() {
default:
log.Error("received a live command that was not QUIT")
return
case "quit", "":
return
}
}
}
// Any incoming message should end the connection
rd.ReadMessages()
}()
go func() {
defer func() {
cond.L.Lock()
mustQuit = true
cond.Broadcast()
cond.L.Unlock()
}()
err := func() error {
_, err := io.Copy(conn, f)
_, err = io.Copy(conn, f)
if err != nil {
return err
}
b := make([]byte, 4096)
// The reader needs to be OK with the eof not
b := make([]byte, 4096*2)
for {
n, err := f.Read(b)
if n > 0 {
@ -547,32 +525,10 @@ func (s *Server) liveAOF(pos int64, conn net.Conn, rd *PipelineReader, msg *Mess
return err
}
}
if err != io.EOF {
if err != nil {
if err == io.EOF {
time.Sleep(time.Second / 4)
} else if err != nil {
return err
}
continue
}
s.fcond.L.Lock()
s.fcond.Wait()
s.fcond.L.Unlock()
}
}()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") &&
!strings.Contains(err.Error(), "bad file descriptor") {
log.Error(err)
}
return
}
}()
for {
cond.L.Lock()
if mustQuit {
cond.L.Unlock()
return nil
}
cond.Wait()
cond.L.Unlock()
}
}

View File

@ -302,10 +302,13 @@ func Serve(opts Options) error {
ln := s.ln
s.ln = nil
s.lnmu.Unlock()
if ln != nil {
ln.Close()
}
for conn, f := range s.aofconnM {
conn.Close()
f.Close()
}
}()
// Load the queue before the aof

View File

@ -2,16 +2,31 @@ package tests
import (
"bytes"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"time"
"github.com/gomodule/redigo/redis"
)
func subTestAOF(g *testGroup) {
g.regSubTest("loading", aof_loading_test)
// g.regSubTest("AOFMD5", aof_AOFMD5_test)
g.regSubTest("AOF", aof_AOF_test)
g.regSubTest("AOFMD5", aof_AOFMD5_test)
}
func loadAOFAndClose(aof any) error {
mc, err := loadAOF(aof)
if mc != nil {
mc.Close()
}
return err
}
func loadAOF(aof any) (*mockServer, error) {
var aofb []byte
switch aof := aof.(type) {
case []byte:
@ -19,17 +34,13 @@ func loadAOFAndClose(aof any) error {
case string:
aofb = []byte(aof)
default:
return errors.New("aof is not string or bytes")
return nil, errors.New("aof is not string or bytes")
}
mc, err := mockOpenServer(MockServerOptions{
return mockOpenServer(MockServerOptions{
Silent: true,
Metrics: false,
AOFData: aofb,
})
if mc != nil {
mc.Close()
}
return err
}
func aof_loading_test(mc *mockServer) error {
@ -79,10 +90,129 @@ func aof_loading_test(mc *mockServer) error {
return fmt.Errorf("expected '%v', got '%v'",
"Protocol error: expected '$', got '+'", err)
}
return nil
}
// func aof_AOFMD5_test(mc *mockServer) error {
// return nil
// }
func aof_AOFMD5_test(mc *mockServer) error {
for i := 0; i < 10000; i++ {
_, err := mc.Do("SET", "fleet", rand.Int(),
"POINT", rand.Float64()*180-90, rand.Float64()*360-180)
if err != nil {
return err
}
}
aof, err := mc.readAOF()
if err != nil {
return err
}
check := func(start, size int) func(s string) error {
return func(s string) error {
sum := md5.Sum(aof[start : start+size])
val := hex.EncodeToString(sum[:])
if s != val {
return fmt.Errorf("expected '%s', got '%s'", val, s)
}
return nil
}
}
return mc.DoBatch(
Do("AOFMD5").Err("wrong number of arguments for 'aofmd5' command"),
Do("AOFMD5", 0).Err("wrong number of arguments for 'aofmd5' command"),
Do("AOFMD5", 0, 0, 1).Err("wrong number of arguments for 'aofmd5' command"),
Do("AOFMD5", -1, 0).Err("invalid argument '-1'"),
Do("AOFMD5", 1, -1).Err("invalid argument '-1'"),
Do("AOFMD5", 0, 100000000000).Err("EOF"),
Do("AOFMD5", 0, 0).Str("d41d8cd98f00b204e9800998ecf8427e"),
Do("AOFMD5", 0, 0).JSON().Str(`{"ok":true,"md5":"d41d8cd98f00b204e9800998ecf8427e"}`),
Do("AOFMD5", 0, 0).Func(check(0, 0)),
Do("AOFMD5", 0, 1).Func(check(0, 1)),
Do("AOFMD5", 0, 100).Func(check(0, 100)),
Do("AOFMD5", 1002, 4321).Func(check(1002, 4321)),
)
}
func aof_AOF_test(mc *mockServer) error {
var argss [][]interface{}
for i := 0; i < 10000; i++ {
args := []interface{}{"SET", "fleet", fmt.Sprint(rand.Int()),
"POINT", fmt.Sprint(rand.Float64()*180 - 90),
fmt.Sprint(rand.Float64()*360 - 180)}
argss = append(argss, args)
_, err := mc.Do(fmt.Sprint(args[0]), args[1:]...)
if err != nil {
return err
}
}
readAll := func() (conn redis.Conn, err error) {
conn, err = redis.Dial("tcp", fmt.Sprintf(":%d", mc.port),
redis.DialReadTimeout(time.Second))
if err != nil {
return nil, err
}
defer func() {
if err != nil {
conn.Close()
conn = nil
}
}()
if err := conn.Send("AOF", 0); err != nil {
return nil, err
}
if err := conn.Flush(); err != nil {
return nil, err
}
str, err := redis.String(conn.Receive())
if err != nil {
return nil, err
}
if str != "OK" {
return nil, fmt.Errorf("expected '%s', got '%s'", "OK", str)
}
var t bool
for i := 0; i < len(argss); i++ {
args, err := redis.Values(conn.Receive())
if err != nil {
return nil, err
}
if t || (len(args) == len(argss[0]) &&
fmt.Sprintf("%s", args[2]) == fmt.Sprintf("%s", argss[0][2])) {
t = true
if fmt.Sprintf("%s", args[2]) !=
fmt.Sprintf("%s", argss[i][2]) {
return nil, fmt.Errorf("expected '%s', got '%s'",
argss[i][2], args[2])
}
} else {
i--
}
}
return conn, nil
}
conn, err := readAll()
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Do("fancy") // non-existent error
if err == nil || err.Error() != "EOF" {
return fmt.Errorf("expected '%v', got '%v'", "EOF", err)
}
conn, err = readAll()
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Do("quit")
if err == nil || err.Error() != "EOF" {
return fmt.Errorf("expected '%v', got '%v'", "EOF", err)
}
return mc.DoBatch(
Do("AOF").Err("wrong number of arguments for 'aof' command"),
Do("AOF", 0, 0).Err("wrong number of arguments for 'aof' command"),
Do("AOF", -1).Err("invalid argument '-1'"),
Do("AOF", 1000000000000).Err("pos is too big, must be less that the aof_size of leader"),
)
}

View File

@ -44,9 +44,9 @@ type mockServer struct {
shutdown chan bool
}
// func (mc *mockServer) readAOF() ([]byte, error) {
// return os.ReadFile(filepath.Join(mc.dir, "appendonly.aof"))
// }
func (mc *mockServer) readAOF() ([]byte, error) {
return os.ReadFile(filepath.Join(mc.dir, "appendonly.aof"))
}
func (mc *mockServer) metricsPort() int {
return mc.mport

View File

@ -125,14 +125,15 @@ func runTestGroups(t *testing.T) {
fmt.Printf(bright+"Testing %s\n"+clear, g.name)
g.printed.Store(true)
}
const frtmp = "[" + magenta + " " + clear + "] %s (running) "
for _, s := range g.subs {
if !s.skipped.Load() && !s.printedName.Load() {
fmt.Printf("[..] %s (running) ", s.name)
fmt.Printf(frtmp, s.name)
s.printedName.Store(true)
}
if s.done.Load() && !s.printedResult.Load() {
fmt.Printf("\r")
msg := fmt.Sprintf("[..] %s (running) ", s.name)
msg := fmt.Sprintf(frtmp, s.name)
fmt.Print(strings.Repeat(" ", len(msg)))
fmt.Printf("\r")
if s.err != nil {