logrus/logger_test.go

225 lines
4.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package logrus
import (
"bytes"
"context"
"encoding/json"
"fmt"
"path/filepath"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFieldValueError logs an entry carrying a field whose value is a func
// through a JSONFormatter and expects the decoded JSON output to contain the
// FieldKeyLogrusError key.
func TestFieldValueError(t *testing.T) {
	buf := &bytes.Buffer{}
	l := &Logger{
		Out:       buf,
		Formatter: new(JSONFormatter),
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}
	l.WithField("func", func() {}).Info("test")

	// t.Log keeps the raw output out of stdout unless the test fails or runs
	// with -v.
	t.Log(buf.String())

	var data map[string]interface{}
	// Stop right here on a decode failure: continuing would only add a second,
	// misleading "field not found" failure on an empty map.
	require.NoError(t, json.Unmarshal(buf.Bytes(), &data), "unexpected error")

	_, ok := data[FieldKeyLogrusError]
	require.True(t, ok, `cannot found expected "logrus_error" field: %v`, data)
}
// TestNoFieldValueError logs an entry carrying a plain string field through a
// JSONFormatter and expects the decoded JSON output NOT to contain the
// FieldKeyLogrusError key.
func TestNoFieldValueError(t *testing.T) {
	buf := &bytes.Buffer{}
	l := &Logger{
		Out:       buf,
		Formatter: new(JSONFormatter),
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}
	l.WithField("str", "str").Info("test")

	// t.Log keeps the raw output out of stdout unless the test fails or runs
	// with -v.
	t.Log(buf.String())

	var data map[string]interface{}
	// Stop right here on a decode failure: an empty map would make the
	// "key is absent" assertion below pass for the wrong reason.
	require.NoError(t, json.Unmarshal(buf.Bytes(), &data), "unexpected error")

	_, ok := data[FieldKeyLogrusError]
	require.False(t, ok)
}
// TestWarninglnNotEqualToWarning calls Warning and Warningln with the same
// arguments, each writing to its own buffer, and expects the two rendered
// outputs to differ.
func TestWarninglnNotEqualToWarning(t *testing.T) {
	buf := &bytes.Buffer{}
	bufln := &bytes.Buffer{}

	formatter := new(TextFormatter)
	formatter.DisableTimestamp = true
	formatter.DisableLevelTruncation = true

	l := &Logger{
		Out:       buf,
		Formatter: formatter,
		Hooks:     make(LevelHooks),
		Level:     DebugLevel,
	}

	l.Warning("hello,", "world")

	// Redirect output so the Warningln result lands in a separate buffer.
	l.SetOutput(bufln)
	l.Warningln("hello,", "world")

	assert.NotEqual(t, buf.String(), bufln.String(), "Warning() and Warningln() should not be equal")
}
// testBufferPool is a recording buffer pool used by the tests: it counts how
// often Get is called and remembers every buffer handed back through Put.
type testBufferPool struct {
	buffers []*bytes.Buffer // buffers received by Put, in call order
	get     int             // number of Get calls so far
}

// Get records the call and returns a freshly allocated, empty buffer; buffers
// previously given to Put are never reused.
func (p *testBufferPool) Get() *bytes.Buffer {
	fresh := &bytes.Buffer{}
	p.get++
	return fresh
}

// Put stores buf at the end of the list of returned buffers.
func (p *testBufferPool) Put(buf *bytes.Buffer) {
	returned := append(p.buffers, buf)
	p.buffers = returned
}
// TestLogger_SetBufferPool installs a recording pool via SetBufferPool, logs
// one message and expects exactly one Get call and one buffer handed back
// through Put.
func TestLogger_SetBufferPool(t *testing.T) {
	out := &bytes.Buffer{}
	l := New()
	l.SetOutput(out)

	pool := new(testBufferPool)
	l.SetBufferPool(pool)

	l.Info("test")

	// testify expects (expected, actual); with the arguments swapped a failure
	// report would label the two values the wrong way round.
	assert.Equal(t, 1, pool.get, "Logger.SetBufferPool(): The BufferPool.Get() must be called")
	assert.Len(t, pool.buffers, 1, "Logger.SetBufferPool(): The BufferPool.Put() must be called")
}
// TestLogger_concurrentLock has two goroutines repeatedly add a trace-id hook
// to the standard logger, log one message and remove the hook again.
//
// The work is bounded and the test waits for both goroutines, so it finishes
// on its own. If the loggers deadlock, the goroutines never complete and the
// run is ended by the `go test` timeout instead of hanging silently.
func TestLogger_concurrentLock(t *testing.T) {
	SetFormatter(&LogFormatter{})

	const iterations = 100

	workers := []struct {
		traceID string
		msg     string
	}{
		{traceID: "123", msg: "test why "},
		{traceID: "1233", msg: "test why 2"},
	}

	var wg sync.WaitGroup
	for _, w := range workers {
		wg.Add(1)
		// Pass the values explicitly so each goroutine has its own copy
		// regardless of the Go version's loop-variable semantics.
		go func(traceID, msg string) {
			defer wg.Done()
			for i := 0; i < iterations; i++ {
				func() {
					defer func() {
						// Keep a panic from killing the test binary, but
						// report it: a panic in the logging path is a
						// failure, not something to hide.
						if p := recover(); p != nil {
							t.Errorf("panic while logging with trace id %q: %v", traceID, p)
						}
					}()
					hook := AddTraceIdHook(traceID)
					defer RemoveTraceHook(hook)
					Infof("%s", msg)
				}()
			}
		}(w.traceID, w.msg)
	}
	wg.Wait()
}
// traceLock is held by AddTraceIdHook for the whole duration of each call.
var traceLock = &sync.Mutex{}
// AddTraceIdHook creates a trace-id hook for traceId, bound to the calling
// goroutine, adds it to the standard logger and returns it so the caller can
// later pass it to RemoveTraceHook.
//
// Calls are serialized through traceLock.
func AddTraceIdHook(traceId string) Hook {
	traceLock.Lock()
	defer traceLock.Unlock()

	traceHook := newTraceIdHook(traceId)
	if StandardLogger().Hooks == nil {
		// The map has to be allocated with make: new(LevelHooks) yields a
		// pointer to a nil map, so dereferencing it would just install nil
		// again and leave nothing for the hook to be stored in.
		StandardLogger().ReplaceHooks(make(LevelHooks))
	}
	AddHook(traceHook)
	return traceHook
}
// RemoveTraceHook removes hook from every level of the standard logger's hook
// set (the first occurrence per level) and installs the result through
// ReplaceHooks.
//
// Lock and Unlock are declared elsewhere in the package; presumably they guard
// the standard logger's state — TODO confirm.
func RemoveTraceHook(hook Hook) {
	allHooks := StandardLogger().Hooks

	// The edit runs in its own function so that Unlock has already happened
	// by the time ReplaceHooks is called below.
	func() {
		Lock()
		defer Unlock()

		for level, hooks := range allHooks {
			for i, h := range hooks {
				if h != hook {
					continue
				}
				// Skip element i (hence i+1). Capping the capacity of the
				// head at i forces append to copy into a new array rather
				// than overwrite the one other readers may still hold.
				allHooks[level] = append(hooks[:i:i], hooks[i+1:]...)
				break
			}
		}
	}()

	StandardLogger().ReplaceHooks(allHooks)
}
// TraceIdHook is a hook that carries a trace id together with the id of one
// goroutine; its Fire method only acts when called on that goroutine.
type TraceIdHook struct {
// TraceId is the value Fire stores in the entry context under "trace_id".
TraceId string
// GID is the goroutine id Fire compares against the calling goroutine's id.
GID uint64
}
// newTraceIdHook builds a TraceIdHook for traceId and binds it to the
// goroutine that calls this constructor by recording that goroutine's id.
func newTraceIdHook(traceId string) Hook {
	hook := new(TraceIdHook)
	hook.TraceId = traceId
	hook.GID = getGID()
	return hook
}
// Levels returns AllLevels (declared elsewhere in the package); presumably
// this subscribes the hook to every log level — confirm against the Hook
// interface.
func (t TraceIdHook) Levels() []Level {
return AllLevels
}
// Fire attaches the hook's trace id to entry.Context under the "trace_id" key
// when it is called on the goroutine whose id is stored in t.GID; on any other
// goroutine it leaves the entry untouched. It always returns nil.
func (t TraceIdHook) Fire(entry *Entry) error {
	if getGID() != t.GID {
		return nil
	}

	// Derive from the entry's existing context when there is one, so values
	// and cancellation already attached to the entry are not thrown away.
	parent := entry.Context
	if parent == nil {
		parent = context.Background()
	}

	// The plain string key is kept on purpose: LogFormatter.Format looks the
	// value up with the same "trace_id" string.
	entry.Context = context.WithValue(parent, "trace_id", t.TraceId)
	return nil
}
// LogFormatter is a stateless formatter; its Format method renders an entry as
// one text line holding a timestamp, the goroutine id, the level, the trace
// id, the caller's file and line, and the message.
type LogFormatter struct{}
// Format renders entry as a single newline-terminated line:
//
//	<timestamp> [<goroutine id>] [<level>] [<trace id>] <file>:<line> <message>
//
// The timestamp is the wall-clock time at formatting, the goroutine id is that
// of the goroutine calling Format, and file/line come from entry.Caller when
// it is set (otherwise they are empty and 0). The trace id is read from
// entry.Context under the "trace_id" key; "NO UUID" is printed when the
// context is nil, has no such key, or holds an empty string. The entry itself
// is not modified. The returned error is always nil.
func (s *LogFormatter) Format(entry *Entry) ([]byte, error) {
	timestamp := time.Now().Format("2006-01-02 15:04:05")

	var file string
	var line int
	if entry.Caller != nil {
		file = filepath.Base(entry.Caller.File)
		line = entry.Caller.Line
	}

	// Context.Value returns nil, not "", for a missing key, so both cases
	// must fall back to the placeholder.
	var traceID interface{} = "NO UUID"
	if entry.Context != nil {
		if v := entry.Context.Value("trace_id"); v != nil && v != "" {
			traceID = v
		}
	}

	msg := fmt.Sprintf("%-15s [%-3d] [%-5s] [%s] %s:%d %s\n", timestamp, getGID(), entry.Level.String(), traceID, file, line, entry.Message)
	return []byte(msg), nil
}
// getGID returns the id of the calling goroutine, or 0 if it cannot be
// determined.
//
// The id is parsed out of the first line of the goroutine's stack trace, which
// is expected to start with "goroutine <id> ".
func getGID() uint64 {
	// Only the start of the trace is needed, so a small fixed buffer is
	// enough; runtime.Stack truncates to what fits.
	var buf [64]byte
	header := buf[:runtime.Stack(buf[:], false)]
	header = bytes.TrimPrefix(header, []byte("goroutine "))

	end := bytes.IndexByte(header, ' ')
	if end < 0 {
		// Unexpected trace layout: slicing with -1 would panic.
		return 0
	}

	id, err := strconv.ParseUint(string(header[:end]), 10, 64)
	if err != nil {
		return 0
	}
	return id
}