logrus/logger_test.go

199 lines
4.3 KiB
Go

package logrus
import (
"bytes"
"context"
"encoding/json"
"fmt"
"path/filepath"
"runtime"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFieldValueError checks that a field whose value cannot be JSON-encoded
// (a func) still yields valid JSON output and that the problem is reported
// under the FieldKeyLogrusError key.
func TestFieldValueError(t *testing.T) {
	buf := &bytes.Buffer{}
	l := &Logger{
		Out:       buf,
		Formatter: new(JSONFormatter),
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}
	l.WithField("func", func() {}).Info("test")
	t.Log(buf.String())

	var data map[string]interface{}
	// Stop here on invalid JSON: the field lookup below would only produce a
	// misleading secondary failure.
	require.NoError(t, json.Unmarshal(buf.Bytes(), &data), "log output is not valid JSON")
	_, ok := data[FieldKeyLogrusError]
	require.True(t, ok, `cannot find expected "logrus_error" field: %v`, data)
}
// TestNoFieldValueError checks that an ordinary, encodable field does not
// produce a FieldKeyLogrusError entry in the JSON output.
func TestNoFieldValueError(t *testing.T) {
	buf := &bytes.Buffer{}
	l := &Logger{
		Out:       buf,
		Formatter: new(JSONFormatter),
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}
	l.WithField("str", "str").Info("test")
	t.Log(buf.String())

	var data map[string]interface{}
	// Abort on invalid JSON; otherwise data stays nil and the absence check
	// below would pass vacuously.
	require.NoError(t, json.Unmarshal(buf.Bytes(), &data), "log output is not valid JSON")
	_, ok := data[FieldKeyLogrusError]
	require.False(t, ok)
}
// TestWarninglnNotEqualToWarning logs the same operands through Warning and
// Warningln into separate buffers and checks that the rendered lines differ.
// Presumably they differ in operand spacing (Sprint vs Sprintln semantics).
func TestWarninglnNotEqualToWarning(t *testing.T) {
	buf := &bytes.Buffer{}
	bufln := &bytes.Buffer{}

	formatter := new(TextFormatter)
	// Drop the timestamp so the two lines can only differ in the message.
	formatter.DisableTimestamp = true
	formatter.DisableLevelTruncation = true

	l := &Logger{
		Out:       buf,
		Formatter: formatter,
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}
	l.Warning("hello,", "world")

	l.SetOutput(bufln)
	l.Warningln("hello,", "world")

	assert.NotEqual(t, buf.String(), bufln.String(), "Warning() and Warningln() should not be equal")
}
// testBufferPool is a test double for the logger's buffer pool. It always
// hands out a brand-new buffer and records its usage:
// get counts calls to Get, buffers collects every buffer passed to Put.
type testBufferPool struct {
	buffers []*bytes.Buffer
	get     int
}

// Get bumps the call counter and returns a fresh, empty buffer.
func (tp *testBufferPool) Get() *bytes.Buffer {
	tp.get++
	fresh := &bytes.Buffer{}
	return fresh
}

// Put remembers b so a test can count how many buffers were returned.
func (tp *testBufferPool) Put(b *bytes.Buffer) {
	returned := tp.buffers
	tp.buffers = append(returned, b)
}
// TestLogger_SetBufferPool checks that a logger configured with a custom
// buffer pool takes exactly one buffer from it and hands exactly one back
// when logging a single entry.
func TestLogger_SetBufferPool(t *testing.T) {
	out := &bytes.Buffer{}
	l := New()
	l.SetOutput(out)

	pool := new(testBufferPool)
	l.SetBufferPool(pool)

	l.Info("test")

	// testify's Equal takes (expected, actual); keep that order so failure
	// output labels the values correctly.
	assert.Equal(t, 1, pool.get, "Logger.SetBufferPool(): The BufferPool.Get() must be called")
	assert.Len(t, pool.buffers, 1, "Logger.SetBufferPool(): The BufferPool.Put() must be called")
}
func TestLogger_concurrentLock(t *testing.T) {
SetFormatter(&LogFormatter{})
go func() {
for {
func() {
defer func() {
if p := recover(); p != nil {
}
}()
hook := AddTraceIdHook("123")
defer RemoveTraceHook(hook)
Infof("test why ")
}()
}
}()
go func() {
for {
func() {
defer func() {
if p := recover(); p != nil {
}
}()
hook := AddTraceIdHook("1233")
defer RemoveTraceHook(hook)
Infof("test why 2")
}()
}
}()
time.Sleep(5 * time.Second)
}
// AddTraceIdHook creates a trace hook bound to the calling goroutine (see
// newTraceIdHook) and registers it on the standard logger via AddHook.
// The returned hook is meant to be handed to RemoveTraceHook when the caller
// is done logging with it.
func AddTraceIdHook(traceId string) Hook {
	traceHook := newTraceIdHook(traceId)
	std := StandardLogger()
	// NOTE(review): this reads std.Hooks without the logger's lock and is a
	// check-then-act sequence, so it is racy when called concurrently (as
	// TestLogger_concurrentLock does). It only matters if Hooks can be nil.
	if std.Hooks == nil {
		// Use make, not new: *new(LevelHooks) is a nil map, which would leave
		// Hooks nil and make AddHook write into a nil map.
		std.ReplaceHooks(make(LevelHooks))
	}
	AddHook(traceHook)
	return traceHook
}
// RemoveTraceHook is the counterpart of AddTraceIdHook. It hands hook to the
// standard logger's ReplaceHook method, which presumably removes it from the
// logger's hook set. NOTE(review): confirm against ReplaceHook's definition.
func RemoveTraceHook(hook Hook) {
	logger := StandardLogger()
	logger.ReplaceHook(hook)
}
// TraceIdHook is a Hook that attaches a trace ID to the context of entries,
// but only for entries logged from the goroutine it was created on.
type TraceIdHook struct {
	// TraceId is the value stored under the "trace_id" context key.
	TraceId string
	// GID is the owning goroutine's ID as returned by getGID; entries logged
	// from any other goroutine are left untouched.
	GID uint64
}
// newTraceIdHook builds a *TraceIdHook carrying traceId. It is pinned to the
// goroutine that calls this constructor, identified via getGID.
func newTraceIdHook(traceId string) Hook {
	h := new(TraceIdHook)
	h.TraceId = traceId
	h.GID = getGID()
	return h
}
// Levels reports that the hook fires for every log level.
func (h TraceIdHook) Levels() []Level {
	return AllLevels
}

// Fire stores the hook's trace ID in entry.Context under the "trace_id" key,
// but only when called on the goroutine that created the hook. Hooks owned
// by other goroutines return without touching the entry.
//
// The new context is derived from the entry's existing context, when there
// is one, so values attached by the caller (e.g. via WithContext) survive.
func (h TraceIdHook) Fire(entry *Entry) error {
	if getGID() != h.GID {
		return nil
	}
	parent := entry.Context
	if parent == nil {
		parent = context.Background()
	}
	// NOTE(review): a plain string key can collide across packages
	// (staticcheck SA1029). It is kept because LogFormatter looks the value
	// up with the same "trace_id" string.
	entry.Context = context.WithValue(parent, "trace_id", h.TraceId)
	return nil
}
// LogFormatter renders each entry as a single plain-text line:
//
//	<timestamp> [<goroutine id>] [<level>] [<trace id>] <file>:<line> <message>
//
// The trace ID is read from entry.Context under the "trace_id" key (the key
// TraceIdHook.Fire writes). "NO UUID" is printed when it is missing or empty.
// File and line are only filled in when the entry carries caller information.
type LogFormatter struct{}

// Format implements Formatter. It always returns a nil error and does not
// modify entry. The goroutine ID printed is that of the goroutine calling
// Format.
func (s *LogFormatter) Format(entry *Entry) ([]byte, error) {
	// Stamp the time the entry was logged. Fall back to "now" for hand-built
	// entries that carry no timestamp.
	logged := entry.Time
	if logged.IsZero() {
		logged = time.Now()
	}
	timestamp := logged.Format("2006-01-02 15:04:05")

	var file string
	var line int
	if entry.Caller != nil {
		file = filepath.Base(entry.Caller.File)
		line = entry.Caller.Line
	}

	var traceID interface{}
	if entry.Context != nil {
		traceID = entry.Context.Value("trace_id")
	}
	// A missing key yields nil. Treat it like an empty ID so the line never
	// renders as "%!s(<nil>)".
	if traceID == nil || traceID == "" {
		traceID = "NO UUID"
	}

	msg := fmt.Sprintf("%-15s [%-3d] [%-5s] [%s] %s:%d %s\n",
		timestamp, getGID(), entry.Level.String(), traceID, file, line, entry.Message)
	return []byte(msg), nil
}
// getGID returns the numeric ID of the calling goroutine. It parses the
// "goroutine <id> [...]" header that runtime.Stack writes for the current
// goroutine and returns 0 if that header cannot be parsed.
//
// Goroutine IDs are not a supported API; this is only suitable for tests and
// diagnostics.
func getGID() uint64 {
	buf := make([]byte, 64)
	buf = buf[:runtime.Stack(buf, false)]
	buf = bytes.TrimPrefix(buf, []byte("goroutine "))
	// Guard against an unexpected stack format: slicing with -1 would panic.
	if end := bytes.IndexByte(buf, ' '); end >= 0 {
		buf = buf[:end]
	}
	id, err := strconv.ParseUint(string(buf), 10, 64)
	if err != nil {
		return 0
	}
	return id
}