diff --git a/internal/server/aof.go b/internal/server/aof.go index ffab4957..250080b1 100644 --- a/internal/server/aof.go +++ b/internal/server/aof.go @@ -399,73 +399,80 @@ func (s liveAOFSwitches) Error() string { return goingLive } -func (s *Server) cmdAOFMD5(msg *Message) (res resp.Value, err error) { +// AOFMD5 pos size +func (s *Server) cmdAOFMD5(msg *Message) (resp.Value, error) { start := time.Now() - vs := msg.Args[1:] - var ok bool - var spos, ssize string - if vs, spos, ok = tokenval(vs); !ok || spos == "" { - return NOMessage, errInvalidNumberOfArguments + // >> Args + + args := msg.Args + if len(args) != 3 { + return retrerr(errInvalidNumberOfArguments) } - if vs, ssize, ok = tokenval(vs); !ok || ssize == "" { - return NOMessage, errInvalidNumberOfArguments - } - if len(vs) != 0 { - return NOMessage, errInvalidNumberOfArguments - } - pos, err := strconv.ParseInt(spos, 10, 64) + pos, err := strconv.ParseInt(args[1], 10, 64) if err != nil || pos < 0 { - return NOMessage, errInvalidArgument(spos) + return retrerr(errInvalidArgument(args[1])) } - size, err := strconv.ParseInt(ssize, 10, 64) + size, err := strconv.ParseInt(args[2], 10, 64) if err != nil || size < 0 { - return NOMessage, errInvalidArgument(ssize) + return retrerr(errInvalidArgument(args[2])) } + + // >> Operation + sum, err := s.checksum(pos, size) if err != nil { - return NOMessage, err + return retrerr(err) } - switch msg.OutputType { - case JSON: - res = resp.StringValue( - fmt.Sprintf(`{"ok":true,"md5":"%s","elapsed":"%s"}`, sum, time.Since(start))) - case RESP: - res = resp.SimpleStringValue(sum) + + // >> Response + + if msg.OutputType == JSON { + return resp.StringValue(fmt.Sprintf( + `{"ok":true,"md5":"%s","elapsed":"%s"}`, + sum, time.Since(start))), nil } - return res, nil + return resp.SimpleStringValue(sum), nil } -func (s *Server) cmdAOF(msg *Message) (res resp.Value, err error) { +// AOF pos +func (s *Server) cmdAOF(msg *Message) (resp.Value, error) { if s.aof == nil { - return NOMessage, errors.New("aof disabled") + return retrerr(errors.New("aof disabled")) } - vs := msg.Args[1:] - var ok bool - var spos string - if vs, spos, ok = tokenval(vs); !ok || spos == "" { - return NOMessage, errInvalidNumberOfArguments + // >> Args + + args := msg.Args + if len(args) != 2 { + return retrerr(errInvalidNumberOfArguments) } - if len(vs) != 0 { - return NOMessage, errInvalidNumberOfArguments - } - pos, err := strconv.ParseInt(spos, 10, 64) + + pos, err := strconv.ParseInt(args[1], 10, 64) if err != nil || pos < 0 { - return NOMessage, errInvalidArgument(spos) + return retrerr(errInvalidArgument(args[1])) } + + // >> Operation + f, err := os.Open(s.aof.Name()) if err != nil { - return NOMessage, err + return retrerr(err) } defer f.Close() + n, err := f.Seek(0, 2) if err != nil { - return NOMessage, err + return retrerr(err) } + if n < pos { - return NOMessage, errors.New("pos is too big, must be less that the aof_size of leader") + return retrerr(errors.New( + "pos is too big, must be less that the aof_size of leader")) } + + // >> Response + var ls liveAOFSwitches ls.pos = pos return NOMessage, ls @@ -478,8 +485,6 @@ func (s *Server) liveAOF(pos int64, conn net.Conn, rd *PipelineReader, msg *Mess if err != nil { return err } - defer f.Close() - s.mu.Lock() s.aofconnM[conn] = f s.mu.Unlock() @@ -488,91 +493,42 @@ func (s *Server) liveAOF(pos int64, conn net.Conn, rd *PipelineReader, msg *Mess delete(s.aofconnM, conn) s.mu.Unlock() conn.Close() + f.Close() }() if _, err := conn.Write([]byte("+OK\r\n")); err != nil { return err } - if _, err := f.Seek(pos, 0); err != nil { return err } - cond := sync.NewCond(&sync.Mutex{}) - var mustQuit bool + var wg sync.WaitGroup + wg.Add(1) go func() { defer func() { - cond.L.Lock() - mustQuit = true - cond.Broadcast() - cond.L.Unlock() + f.Close() + conn.Close() + wg.Done() }() - for { - vs, err := rd.ReadMessages() - if err != nil { - if err != io.EOF { - log.Error(err) - } - return - } - for _, v := range vs { - switch v.Command() { - default: - log.Error("received a live command that was not QUIT") - return - case "quit", "": - return - } - } - } + // Any incoming message should end the connection + rd.ReadMessages() }() - go func() { - defer func() { - cond.L.Lock() - mustQuit = true - cond.Broadcast() - cond.L.Unlock() - }() - err := func() error { - _, err := io.Copy(conn, f) - if err != nil { + _, err = io.Copy(conn, f) + if err != nil { + return err + } + b := make([]byte, 4096*2) + for { + n, err := f.Read(b) + if n > 0 { + if _, err := conn.Write(b[:n]); err != nil { return err } - - b := make([]byte, 4096) - // The reader needs to be OK with the eof not - for { - n, err := f.Read(b) - if n > 0 { - if _, err := conn.Write(b[:n]); err != nil { - return err - } - } - if err != io.EOF { - if err != nil { - return err - } - continue - } - s.fcond.L.Lock() - s.fcond.Wait() - s.fcond.L.Unlock() - } - }() - if err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") && - !strings.Contains(err.Error(), "bad file descriptor") { - log.Error(err) - } - return } - }() - for { - cond.L.Lock() - if mustQuit { - cond.L.Unlock() - return nil + if err == io.EOF { + time.Sleep(time.Second / 4) + } else if err != nil { + return err } - cond.Wait() - cond.L.Unlock() } } diff --git a/internal/server/server.go b/internal/server/server.go index dd4948a7..ea49da26 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -302,10 +302,13 @@ func Serve(opts Options) error { ln := s.ln s.ln = nil s.lnmu.Unlock() - if ln != nil { ln.Close() } + for conn, f := range s.aofconnM { + conn.Close() + f.Close() + } }() // Load the queue before the aof diff --git a/tests/aof_test.go b/tests/aof_test.go index edde81d6..6ffb0b1b 100644 --- a/tests/aof_test.go +++ b/tests/aof_test.go @@ -2,16 +2,31 @@ package tests import ( "bytes" + "crypto/md5" + "encoding/hex" "errors" "fmt" + "math/rand" + "time" + + "github.com/gomodule/redigo/redis" ) func subTestAOF(g *testGroup) { g.regSubTest("loading", aof_loading_test) - // g.regSubTest("AOFMD5", aof_AOFMD5_test) + g.regSubTest("AOF", aof_AOF_test) + g.regSubTest("AOFMD5", aof_AOFMD5_test) } func loadAOFAndClose(aof any) error { + mc, err := loadAOF(aof) + if mc != nil { + mc.Close() + } + return err +} + +func loadAOF(aof any) (*mockServer, error) { var aofb []byte switch aof := aof.(type) { case []byte: @@ -19,17 +34,13 @@ func loadAOFAndClose(aof any) error { case string: aofb = []byte(aof) default: - return errors.New("aof is not string or bytes") + return nil, errors.New("aof is not string or bytes") } - mc, err := mockOpenServer(MockServerOptions{ + return mockOpenServer(MockServerOptions{ Silent: true, Metrics: false, AOFData: aofb, }) - if mc != nil { - mc.Close() - } - return err } func aof_loading_test(mc *mockServer) error { @@ -79,10 +90,129 @@ func aof_loading_test(mc *mockServer) error { return fmt.Errorf("expected '%v', got '%v'", "Protocol error: expected '$', got '+'", err) } - return nil } -// func aof_AOFMD5_test(mc *mockServer) error { -// return nil -// } +func aof_AOFMD5_test(mc *mockServer) error { + for i := 0; i < 10000; i++ { + _, err := mc.Do("SET", "fleet", rand.Int(), + "POINT", rand.Float64()*180-90, rand.Float64()*360-180) + if err != nil { + return err + } + } + aof, err := mc.readAOF() + if err != nil { + return err + } + check := func(start, size int) func(s string) error { + return func(s string) error { + sum := md5.Sum(aof[start : start+size]) + val := hex.EncodeToString(sum[:]) + if s != val { + return fmt.Errorf("expected '%s', got '%s'", val, s) + } + return nil + } + } + return mc.DoBatch( + Do("AOFMD5").Err("wrong number of arguments for 'aofmd5' command"), + Do("AOFMD5", 0).Err("wrong number of arguments for 'aofmd5' command"), + Do("AOFMD5", 0, 0, 1).Err("wrong number of arguments for 'aofmd5' command"), + Do("AOFMD5", -1, 0).Err("invalid argument '-1'"), + Do("AOFMD5", 1, -1).Err("invalid argument '-1'"), + Do("AOFMD5", 0, 100000000000).Err("EOF"), + Do("AOFMD5", 0, 0).Str("d41d8cd98f00b204e9800998ecf8427e"), + Do("AOFMD5", 0, 0).JSON().Str(`{"ok":true,"md5":"d41d8cd98f00b204e9800998ecf8427e"}`), + Do("AOFMD5", 0, 0).Func(check(0, 0)), + Do("AOFMD5", 0, 1).Func(check(0, 1)), + Do("AOFMD5", 0, 100).Func(check(0, 100)), + Do("AOFMD5", 1002, 4321).Func(check(1002, 4321)), + ) +} + +func aof_AOF_test(mc *mockServer) error { + var argss [][]interface{} + for i := 0; i < 10000; i++ { + args := []interface{}{"SET", "fleet", fmt.Sprint(rand.Int()), + "POINT", fmt.Sprint(rand.Float64()*180 - 90), + fmt.Sprint(rand.Float64()*360 - 180)} + argss = append(argss, args) + _, err := mc.Do(fmt.Sprint(args[0]), args[1:]...) + if err != nil { + return err + } + } + readAll := func() (conn redis.Conn, err error) { + conn, err = redis.Dial("tcp", fmt.Sprintf(":%d", mc.port), + redis.DialReadTimeout(time.Second)) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + conn.Close() + conn = nil + } + }() + if err := conn.Send("AOF", 0); err != nil { + return nil, err + } + if err := conn.Flush(); err != nil { + return nil, err + } + str, err := redis.String(conn.Receive()) + if err != nil { + return nil, err + } + if str != "OK" { + return nil, fmt.Errorf("expected '%s', got '%s'", "OK", str) + } + var t bool + for i := 0; i < len(argss); i++ { + args, err := redis.Values(conn.Receive()) + if err != nil { + return nil, err + } + if t || (len(args) == len(argss[0]) && + fmt.Sprintf("%s", args[2]) == fmt.Sprintf("%s", argss[0][2])) { + t = true + if fmt.Sprintf("%s", args[2]) != + fmt.Sprintf("%s", argss[i][2]) { + return nil, fmt.Errorf("expected '%s', got '%s'", + argss[i][2], args[2]) + } + } else { + i-- + } + } + return conn, nil + } + + conn, err := readAll() + if err != nil { + return err + } + defer conn.Close() + _, err = conn.Do("fancy") // non-existent error + if err == nil || err.Error() != "EOF" { + return fmt.Errorf("expected '%v', got '%v'", "EOF", err) + } + + conn, err = readAll() + if err != nil { + return err + } + defer conn.Close() + _, err = conn.Do("quit") + if err == nil || err.Error() != "EOF" { + return fmt.Errorf("expected '%v', got '%v'", "EOF", err) + } + + return mc.DoBatch( + Do("AOF").Err("wrong number of arguments for 'aof' command"), + Do("AOF", 0, 0).Err("wrong number of arguments for 'aof' command"), + Do("AOF", -1).Err("invalid argument '-1'"), + Do("AOF", 1000000000000).Err("pos is too big, must be less that the aof_size of leader"), + ) +} diff --git a/tests/mock_test.go b/tests/mock_test.go index 36a30814..c87015b1 100644 --- a/tests/mock_test.go +++ b/tests/mock_test.go @@ -44,9 +44,9 @@ type mockServer struct { shutdown chan bool } -// func (mc *mockServer) readAOF() ([]byte, error) { -// return os.ReadFile(filepath.Join(mc.dir, "appendonly.aof")) -// } +func (mc *mockServer) readAOF() ([]byte, error) { + return os.ReadFile(filepath.Join(mc.dir, "appendonly.aof")) +} func (mc *mockServer) metricsPort() int { return mc.mport diff --git a/tests/tests_test.go b/tests/tests_test.go index b6d33eb8..7c2abcd0 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -125,14 +125,15 @@ func runTestGroups(t *testing.T) { fmt.Printf(bright+"Testing %s\n"+clear, g.name) g.printed.Store(true) } + const frtmp = "[" + magenta + " " + clear + "] %s (running) " for _, s := range g.subs { if !s.skipped.Load() && !s.printedName.Load() { - fmt.Printf("[..] %s (running) ", s.name) + fmt.Printf(frtmp, s.name) s.printedName.Store(true) } if s.done.Load() && !s.printedResult.Load() { fmt.Printf("\r") - msg := fmt.Sprintf("[..] %s (running) ", s.name) + msg := fmt.Sprintf(frtmp, s.name) fmt.Print(strings.Repeat(" ", len(msg))) fmt.Printf("\r") if s.err != nil {