fix for entry data field race condition

This commit is contained in:
David Bariod 2021-02-16 10:29:37 +01:00
parent f513f99c15
commit ac6e35b4c2
5 changed files with 79 additions and 34 deletions

View File

@ -78,8 +78,20 @@ func NewEntry(logger *Logger) *Entry {
} }
} }
func (entry *Entry) Dup() *Entry {
data := make(Fields, len(entry.Data))
for k, v := range entry.Data {
data[k] = v
}
return &Entry{Logger: entry.Logger, Data: data, Time: entry.Time, Context: entry.Context, err: entry.err}
}
// Returns the bytes representation of this entry from the formatter. // Returns the bytes representation of this entry from the formatter.
func (entry *Entry) Bytes() ([]byte, error) { func (entry *Entry) Bytes() ([]byte, error) {
return entry.bytes_nolock()
}
func (entry *Entry) bytes_nolock() ([]byte, error) {
return entry.Logger.Formatter.Format(entry) return entry.Logger.Formatter.Format(entry)
} }
@ -212,51 +224,49 @@ func (entry Entry) HasCaller() (has bool) {
// This function is not declared with a pointer value because otherwise // This function is not declared with a pointer value because otherwise
// race conditions will occur when using multiple goroutines // race conditions will occur when using multiple goroutines
func (entry Entry) log(level Level, msg string) { func (entry *Entry) log(level Level, msg string) {
var buffer *bytes.Buffer var buffer *bytes.Buffer
// Default to now, but allow users to override if they want. newEntry := entry.Dup()
//
// We don't have to worry about polluting future calls to Entry#log() if newEntry.Time.IsZero() {
// with this assignment because this function is declared with a newEntry.Time = time.Now()
// non-pointer receiver.
if entry.Time.IsZero() {
entry.Time = time.Now()
} }
entry.Level = level newEntry.Level = level
entry.Message = msg newEntry.Message = msg
entry.Logger.mu.Lock()
if entry.Logger.ReportCaller {
entry.Caller = getCaller()
}
entry.Logger.mu.Unlock()
entry.fireHooks() newEntry.Logger.mu.Lock()
reportCaller := newEntry.Logger.ReportCaller
newEntry.Logger.mu.Unlock()
if reportCaller {
newEntry.Caller = getCaller()
}
newEntry.fireHooks()
buffer = getBuffer() buffer = getBuffer()
defer func() { defer func() {
entry.Buffer = nil newEntry.Buffer = nil
putBuffer(buffer) putBuffer(buffer)
}() }()
buffer.Reset() buffer.Reset()
entry.Buffer = buffer newEntry.Buffer = buffer
entry.write() newEntry.write()
entry.Buffer = nil newEntry.Buffer = nil
// To avoid Entry#log() returning a value that only would make sense for // To avoid Entry#log() returning a value that only would make sense for
// panic() to use in Entry#Panic(), we avoid the allocation by checking // panic() to use in Entry#Panic(), we avoid the allocation by checking
// directly here. // directly here.
if level <= PanicLevel { if level <= PanicLevel {
panic(&entry) panic(newEntry)
} }
} }
func (entry *Entry) fireHooks() { func (entry *Entry) fireHooks() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
err := entry.Logger.Hooks.Fire(entry.Level, entry) err := entry.Logger.Hooks.Fire(entry.Level, entry)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err) fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
@ -264,16 +274,18 @@ func (entry *Entry) fireHooks() {
} }
func (entry *Entry) write() { func (entry *Entry) write() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
serialized, err := entry.Logger.Formatter.Format(entry) serialized, err := entry.Logger.Formatter.Format(entry)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err) fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
return return
} }
if _, err = entry.Logger.Out.Write(serialized); err != nil { func() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
if _, err := entry.Logger.Out.Write(serialized); err != nil {
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
} }
}()
} }
func (entry *Entry) Log(level Level, args ...interface{}) { func (entry *Entry) Log(level Level, args ...interface{}) {

View File

@ -232,7 +232,7 @@ func TestEntryWithIncorrectField(t *testing.T) {
fn := func() {} fn := func() {}
e := Entry{} e := Entry{Logger: New()}
eWithFunc := e.WithFields(Fields{"func": fn}) eWithFunc := e.WithFields(Fields{"func": fn})
eWithFuncPtr := e.WithFields(Fields{"funcPtr": &fn}) eWithFuncPtr := e.WithFields(Fields{"funcPtr": &fn})

View File

@ -12,7 +12,7 @@ import (
// LogFunction For big messages, it can be more efficient to pass a function // LogFunction For big messages, it can be more efficient to pass a function
// and only call it if the log level is actually enables rather than // and only call it if the log level is actually enables rather than
// generating the log message and then checking if the level is enabled // generating the log message and then checking if the level is enabled
type LogFunction func()[]interface{} type LogFunction func() []interface{}
type Logger struct { type Logger struct {
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a // The logs are `io.Copy`'d to this in a mutex. It's common to set this to a

View File

@ -25,7 +25,7 @@ func TestFieldValueError(t *testing.T) {
t.Error("unexpected error", err) t.Error("unexpected error", err)
} }
_, ok := data[FieldKeyLogrusError] _, ok := data[FieldKeyLogrusError]
require.True(t, ok) require.True(t, ok, `cannot found expected "logrus_error" field: %v`, data)
} }
func TestNoFieldValueError(t *testing.T) { func TestNoFieldValueError(t *testing.T) {

View File

@ -588,15 +588,48 @@ func TestLoggingRaceWithHooksOnEntry(t *testing.T) {
logger.AddHook(hook) logger.AddHook(hook)
entry := logger.WithField("context", "clue") entry := logger.WithField("context", "clue")
var wg sync.WaitGroup var (
wg sync.WaitGroup
mtx sync.Mutex
start bool
)
cond := sync.NewCond(&mtx)
wg.Add(100) wg.Add(100)
for i := 0; i < 100; i++ { for i := 0; i < 50; i++ {
go func() { go func() {
cond.L.Lock()
for !start {
cond.Wait()
}
cond.L.Unlock()
for j := 0; j < 100; j++ {
entry.Info("info") entry.Info("info")
}
wg.Done() wg.Done()
}() }()
} }
for i := 0; i < 50; i++ {
go func() {
cond.L.Lock()
for !start {
cond.Wait()
}
cond.L.Unlock()
for j := 0; j < 100; j++ {
entry.WithField("another field", "with some data").Info("info")
}
wg.Done()
}()
}
cond.L.Lock()
start = true
cond.L.Unlock()
cond.Broadcast()
wg.Wait() wg.Wait()
} }