Replace abool/aint with new go 1.19 atomics

This commit is contained in:
tidwall 2022-09-27 10:15:31 -07:00
parent 06ebeecf4a
commit 8608ed0917
10 changed files with 53 additions and 101 deletions

View File

@ -1,32 +0,0 @@
package server
import (
"sync/atomic"
)
type aint struct{ v uintptr }
func (a *aint) add(d int) int {
if d < 0 {
return int(atomic.AddUintptr(&a.v, ^uintptr((d*-1)-1)))
}
return int(atomic.AddUintptr(&a.v, uintptr(d)))
}
func (a *aint) get() int {
return int(atomic.LoadUintptr(&a.v))
}
func (a *aint) set(i int) int {
return int(atomic.SwapUintptr(&a.v, uintptr(i)))
}
type abool struct{ v uint32 }
func (a *abool) on() bool {
return atomic.LoadUint32(&a.v) != 0
}
func (a *abool) set(t bool) bool {
if t {
return atomic.SwapUint32(&a.v, 1) != 0
}
return atomic.SwapUint32(&a.v, 0) != 0
}

View File

@ -1,19 +0,0 @@
package server
import "testing"
func TestAtomicInt(t *testing.T) {
var x aint
x.set(10)
if x.get() != 10 {
t.Fatalf("expected %v, got %v", 10, x.get())
}
x.add(-9)
if x.get() != 1 {
t.Fatalf("expected %v, got %v", 1, x.get())
}
x.add(-1)
if x.get() != 0 {
t.Fatalf("expected %v, got %v", 0, x.get())
}
}

View File

@ -144,7 +144,7 @@ func (s *Server) followCheckSome(addr string, followc int, auth string,
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.followc.get() != followc { if int(s.followc.Load()) != followc {
return 0, errNoLongerFollowing return 0, errNoLongerFollowing
} }
if s.aofsz < checksumsz { if s.aofsz < checksumsz {

View File

@ -573,7 +573,7 @@ func (s *Server) cmdFLUSHDB(msg *Message) (resp.Value, commandDetails, error) {
// (HASH geohash)|(STRING value) // (HASH geohash)|(STRING value)
func (s *Server) cmdSET(msg *Message) (resp.Value, commandDetails, error) { func (s *Server) cmdSET(msg *Message) (resp.Value, commandDetails, error) {
start := time.Now() start := time.Now()
if s.config.maxMemory() > 0 && s.outOfMemory.on() { if s.config.maxMemory() > 0 && s.outOfMemory.Load() {
return retwerr(errOOM) return retwerr(errOOM)
} }
@ -780,7 +780,7 @@ func retrerr(err error) (resp.Value, error) {
// FSET key id [XX] field value [field value...] // FSET key id [XX] field value [field value...]
func (s *Server) cmdFSET(msg *Message) (resp.Value, commandDetails, error) { func (s *Server) cmdFSET(msg *Message) (resp.Value, commandDetails, error) {
start := time.Now() start := time.Now()
if s.config.maxMemory() > 0 && s.outOfMemory.on() { if s.config.maxMemory() > 0 && s.outOfMemory.Load() {
return retwerr(errOOM) return retwerr(errOOM)
} }

View File

@ -83,10 +83,11 @@ func (s *Server) cmdFollow(msg *Message) (res resp.Value, err error) {
} }
s.config.write(false) s.config.write(false)
if update { if update {
s.followc.add(1) s.followc.Add(1)
if s.config.followHost() != "" { if s.config.followHost() != "" {
log.Infof("following new host '%s' '%s'.", host, sport) log.Infof("following new host '%s' '%s'.", host, sport)
go s.follow(s.config.followHost(), s.config.followPort(), s.followc.get()) go s.follow(s.config.followHost(), s.config.followPort(),
int(s.followc.Load()))
} else { } else {
log.Infof("following no one") log.Infof("following no one")
} }
@ -152,7 +153,7 @@ func doServer(conn *RESPConn) (map[string]string, error) {
func (s *Server) followHandleCommand(args []string, followc int, w io.Writer) (int, error) { func (s *Server) followHandleCommand(args []string, followc int, w io.Writer) (int, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.followc.get() != followc { if int(s.followc.Load()) != followc {
return s.aofsz, errNoLongerFollowing return s.aofsz, errNoLongerFollowing
} }
msg := &Message{Args: args} msg := &Message{Args: args}
@ -187,7 +188,7 @@ func (s *Server) followDoLeaderAuth(conn *RESPConn, auth string) error {
} }
func (s *Server) followStep(host string, port int, followc int) error { func (s *Server) followStep(host string, port int, followc int) error {
if s.followc.get() != followc { if int(s.followc.Load()) != followc {
return errNoLongerFollowing return errNoLongerFollowing
} }
s.mu.Lock() s.mu.Lock()

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/tidwall/buntdb" "github.com/tidwall/buntdb"
@ -501,7 +502,7 @@ type Hook struct {
query string query string
epm *endpoint.Manager epm *endpoint.Manager
expires time.Time expires time.Time
counter *aint // counter that grows when a message was sent counter *atomic.Int64 // counter that grows when a message was sent
sig int sig int
} }
@ -701,7 +702,7 @@ func (h *Hook) proc() (ok bool) {
} }
log.Debugf("Endpoint send ok: %v: %v: %v", idx, endpoint, err) log.Debugf("Endpoint send ok: %v: %v: %v", idx, endpoint, err)
sent = true sent = true
h.counter.add(1) h.counter.Add(1)
break break
} }
if !sent { if !sent {

View File

@ -11,6 +11,7 @@ import (
"github.com/tidwall/redcon" "github.com/tidwall/redcon"
"github.com/tidwall/tile38/internal/log" "github.com/tidwall/tile38/internal/log"
"go.uber.org/atomic"
) )
type liveBuffer struct { type liveBuffer struct {
@ -23,12 +24,12 @@ type liveBuffer struct {
func (s *Server) processLives(wg *sync.WaitGroup) { func (s *Server) processLives(wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
var done abool var done atomic.Bool
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
for { for {
if done.on() { if done.Load() {
break break
} }
s.lcond.Broadcast() s.lcond.Broadcast()
@ -38,8 +39,8 @@ func (s *Server) processLives(wg *sync.WaitGroup) {
s.lcond.L.Lock() s.lcond.L.Lock()
defer s.lcond.L.Unlock() defer s.lcond.L.Unlock()
for { for {
if s.stopServer.on() { if s.stopServer.Load() {
done.set(true) done.Store(true)
return return
} }
for len(s.lstack) > 0 { for len(s.lstack) > 0 {
@ -211,7 +212,7 @@ func (s *Server) goLive(
return nil // nil return is fine here return nil // nil return is fine here
} }
} }
s.statsTotalMsgsSent.add(len(msgs)) s.statsTotalMsgsSent.Add(int64(len(msgs)))
lb.cond.L.Lock() lb.cond.L.Lock()
} }

View File

@ -280,7 +280,7 @@ func (s *Server) liveSubscription(
write(b) write(b)
} }
} }
s.statsTotalMsgsSent.add(1) s.statsTotalMsgsSent.Add(1)
} }
m := [2]map[string]bool{ m := [2]map[string]bool{

View File

@ -89,15 +89,15 @@ type Server struct {
http500Errors bool http500Errors bool
// atomics // atomics
followc aint // counter increases when follow property changes followc atomic.Int64 // counter when follow property changes
statsTotalConns aint // counter for total connections statsTotalConns atomic.Int64 // counter for total connections
statsTotalCommands aint // counter for total commands statsTotalCommands atomic.Int64 // counter for total commands
statsTotalMsgsSent aint // counter for total sent webhook messages statsTotalMsgsSent atomic.Int64 // counter for total sent webhook messages
statsExpired aint // item expiration counter statsExpired atomic.Int64 // item expiration counter
lastShrinkDuration aint lastShrinkDuration atomic.Int64
stopServer abool stopServer atomic.Bool
outOfMemory abool outOfMemory atomic.Bool
loadedAndReady abool // server is loaded and ready for commands loadedAndReady atomic.Bool // server is loaded and ready for commands
connsmu sync.RWMutex connsmu sync.RWMutex
conns map[int]*Client conns map[int]*Client
@ -296,7 +296,7 @@ func Serve(opts Options) error {
go func() { go func() {
<-opts.Shutdown <-opts.Shutdown
s.stopServer.set(true) s.stopServer.Store(true)
log.Warnf("Shutting down...") log.Warnf("Shutting down...")
s.lnmu.Lock() s.lnmu.Lock()
ln := s.ln ln := s.ln
@ -363,7 +363,7 @@ func Serve(opts Options) error {
go func() { go func() {
defer bgwg.Done() defer bgwg.Done()
s.follow(s.config.followHost(), s.config.followPort(), s.follow(s.config.followHost(), s.config.followPort(),
s.followc.get()) int(s.followc.Load()))
}() }()
} }
@ -382,7 +382,7 @@ func Serve(opts Options) error {
smux.HandleFunc("/metrics", s.MetricsHandler) smux.HandleFunc("/metrics", s.MetricsHandler)
err := http.Serve(mln, smux) err := http.Serve(mln, smux)
if err != nil { if err != nil {
if !s.stopServer.on() { if !s.stopServer.Load() {
log.Fatalf("metrics server: %s", err) log.Fatalf("metrics server: %s", err)
} }
} }
@ -404,8 +404,8 @@ func Serve(opts Options) error {
defer func() { defer func() {
log.Debug("Stopping background routines") log.Debug("Stopping background routines")
// Stop background routines // Stop background routines
s.followc.add(1) // this will force any follow communication to die s.followc.Add(1) // this will force any follow communication to die
s.stopServer.set(true) s.stopServer.Store(true)
if mln != nil { if mln != nil {
mln.Close() // Stop the metrics server mln.Close() // Stop the metrics server
} }
@ -413,7 +413,7 @@ func Serve(opts Options) error {
}() }()
// Server is now loaded and ready. Wait for network error messages. // Server is now loaded and ready. Wait for network error messages.
s.loadedAndReady.set(true) s.loadedAndReady.Store(true)
return <-nerr return <-nerr
} }
@ -466,7 +466,7 @@ func (s *Server) netServe() error {
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
if s.stopServer.on() { if s.stopServer.Load() {
return nil return nil
} }
log.Warn(err) log.Warn(err)
@ -489,7 +489,7 @@ func (s *Server) netServe() error {
s.connsmu.Lock() s.connsmu.Lock()
s.conns[client.id] = client s.conns[client.id] = client
s.connsmu.Unlock() s.connsmu.Unlock()
s.statsTotalConns.add(1) s.statsTotalConns.Add(1)
// set the client keep-alive, if needed // set the client keep-alive, if needed
if s.config.keepAlive() > 0 { if s.config.keepAlive() > 0 {
@ -568,7 +568,7 @@ func (s *Server) netServe() error {
client.mu.Unlock() client.mu.Unlock()
// update total command count // update total command count
s.statsTotalCommands.add(1) s.statsTotalCommands.Add(1)
// handle the command // handle the command
err := s.handleInputCommand(client, msg) err := s.handleInputCommand(client, msg)
@ -728,14 +728,14 @@ func (s *Server) watchAutoGC(wg *sync.WaitGroup) {
} }
func (s *Server) checkOutOfMemory() { func (s *Server) checkOutOfMemory() {
if s.stopServer.on() { if s.stopServer.Load() {
return return
} }
oom := s.outOfMemory.on() oom := s.outOfMemory.Load()
var mem runtime.MemStats var mem runtime.MemStats
if s.config.maxMemory() == 0 { if s.config.maxMemory() == 0 {
if oom { if oom {
s.outOfMemory.set(false) s.outOfMemory.Store(false)
} }
return return
} }
@ -743,13 +743,13 @@ func (s *Server) checkOutOfMemory() {
runtime.GC() runtime.GC()
} }
runtime.ReadMemStats(&mem) runtime.ReadMemStats(&mem)
s.outOfMemory.set(int(mem.HeapAlloc) > s.config.maxMemory()) s.outOfMemory.Store(int(mem.HeapAlloc) > s.config.maxMemory())
} }
func (s *Server) loopUntilServerStops(dur time.Duration, op func()) { func (s *Server) loopUntilServerStops(dur time.Duration, op func()) {
var last time.Time var last time.Time
for { for {
if s.stopServer.on() { if s.stopServer.Load() {
return return
} }
now := time.Now() now := time.Now()
@ -923,7 +923,7 @@ func (s *Server) handleInputCommand(client *Client, msg *Message) error {
return nil return nil
} }
if !s.loadedAndReady.on() { if !s.loadedAndReady.Load() {
switch msg.Command() { switch msg.Command() {
case "output", "ping", "echo": case "output", "ping", "echo":
default: default:

View File

@ -322,7 +322,7 @@ func (s *Server) extStats(m map[string]interface{}) {
// Whether or not an AOF shrink is currently in progress // Whether or not an AOF shrink is currently in progress
m["tile38_aof_rewrite_in_progress"] = s.shrinking m["tile38_aof_rewrite_in_progress"] = s.shrinking
// Length of time the last AOF shrink took // Length of time the last AOF shrink took
m["tile38_aof_last_rewrite_time_sec"] = s.lastShrinkDuration.get() / int(time.Second) m["tile38_aof_last_rewrite_time_sec"] = s.lastShrinkDuration.Load() / int64(time.Second)
// Duration of the on-going AOF rewrite operation if any // Duration of the on-going AOF rewrite operation if any
var currentShrinkStart time.Time var currentShrinkStart time.Time
if currentShrinkStart.IsZero() { if currentShrinkStart.IsZero() {
@ -335,13 +335,13 @@ func (s *Server) extStats(m map[string]interface{}) {
// Whether or no the HTTP transport is being served // Whether or no the HTTP transport is being served
m["tile38_http_transport"] = s.http m["tile38_http_transport"] = s.http
// Number of connections accepted by the server // Number of connections accepted by the server
m["tile38_total_connections_received"] = s.statsTotalConns.get() m["tile38_total_connections_received"] = s.statsTotalConns.Load()
// Number of commands processed by the server // Number of commands processed by the server
m["tile38_total_commands_processed"] = s.statsTotalCommands.get() m["tile38_total_commands_processed"] = s.statsTotalCommands.Load()
// Number of webhook messages sent by server // Number of webhook messages sent by server
m["tile38_total_messages_sent"] = s.statsTotalMsgsSent.get() m["tile38_total_messages_sent"] = s.statsTotalMsgsSent.Load()
// Number of key expiration events // Number of key expiration events
m["tile38_expired_keys"] = s.statsExpired.get() m["tile38_expired_keys"] = s.statsExpired.Load()
// Number of connected slaves // Number of connected slaves
m["tile38_connected_slaves"] = len(s.aofconnM) m["tile38_connected_slaves"] = len(s.aofconnM)
@ -411,7 +411,7 @@ func boolInt(t bool) int {
func (s *Server) writeInfoPersistence(w *bytes.Buffer) { func (s *Server) writeInfoPersistence(w *bytes.Buffer) {
fmt.Fprintf(w, "aof_enabled:%d\r\n", boolInt(s.opts.AppendOnly)) fmt.Fprintf(w, "aof_enabled:%d\r\n", boolInt(s.opts.AppendOnly))
fmt.Fprintf(w, "aof_rewrite_in_progress:%d\r\n", boolInt(s.shrinking)) // Flag indicating a AOF rewrite operation is on-going fmt.Fprintf(w, "aof_rewrite_in_progress:%d\r\n", boolInt(s.shrinking)) // Flag indicating a AOF rewrite operation is on-going
fmt.Fprintf(w, "aof_last_rewrite_time_sec:%d\r\n", s.lastShrinkDuration.get()/int(time.Second)) // Duration of the last AOF rewrite operation in seconds fmt.Fprintf(w, "aof_last_rewrite_time_sec:%d\r\n", s.lastShrinkDuration.Load()/int64(time.Second)) // Duration of the last AOF rewrite operation in seconds
var currentShrinkStart time.Time // c.currentShrinkStart.get() var currentShrinkStart time.Time // c.currentShrinkStart.get()
if currentShrinkStart.IsZero() { if currentShrinkStart.IsZero() {
@ -422,10 +422,10 @@ func (s *Server) writeInfoPersistence(w *bytes.Buffer) {
} }
func (s *Server) writeInfoStats(w *bytes.Buffer) { func (s *Server) writeInfoStats(w *bytes.Buffer) {
fmt.Fprintf(w, "total_connections_received:%d\r\n", s.statsTotalConns.get()) // Total number of connections accepted by the server fmt.Fprintf(w, "total_connections_received:%d\r\n", s.statsTotalConns.Load()) // Total number of connections accepted by the server
fmt.Fprintf(w, "total_commands_processed:%d\r\n", s.statsTotalCommands.get()) // Total number of commands processed by the server fmt.Fprintf(w, "total_commands_processed:%d\r\n", s.statsTotalCommands.Load()) // Total number of commands processed by the server
fmt.Fprintf(w, "total_messages_sent:%d\r\n", s.statsTotalMsgsSent.get()) // Total number of commands processed by the server fmt.Fprintf(w, "total_messages_sent:%d\r\n", s.statsTotalMsgsSent.Load()) // Total number of commands processed by the server
fmt.Fprintf(w, "expired_keys:%d\r\n", s.statsExpired.get()) // Total number of key expiration events fmt.Fprintf(w, "expired_keys:%d\r\n", s.statsExpired.Load()) // Total number of key expiration events
} }
// writeInfoReplication writes all replication data to the 'info' response // writeInfoReplication writes all replication data to the 'info' response