// Copyright 2014 Manu Martinez-Almeida. All rights reserved. // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. package gin import ( "bytes" "errors" "fmt" "io" "log" "net" "net/http" "net/http/httputil" "os" "runtime" "strings" "time" ) var ( dunno = []byte("???") centerDot = []byte("·") dot = []byte(".") slash = []byte("/") fileCache = make(map[string][][]byte) // 优化:缓存文件内容,避免重复读取 ) // RecoveryFunc defines the function passable to CustomRecovery. type RecoveryFunc func(c *Context, err any) // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. func Recovery() HandlerFunc { return RecoveryWithWriter(DefaultErrorWriter) } // CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it. func CustomRecovery(handle RecoveryFunc) HandlerFunc { return RecoveryWithWriter(DefaultErrorWriter, handle) } // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc { if len(recovery) > 0 { return CustomRecoveryWithWriter(out, recovery[0]) } return CustomRecoveryWithWriter(out, defaultHandleRecovery) } // CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it. func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc { var logger *log.Logger if out != nil { logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags) } return func(c *Context) { defer func() { if err := recover(); err != nil { // Check for a broken connection, as it is not really a condition that warrants a panic stack trace. var brokenPipe bool if ne, ok := err.(*net.OpError); ok { var se *os.SyscallError if errors.As(ne, &se) { seStr := strings.ToLower(se.Error()) if strings.Contains(seStr, "broken pipe") || strings.Contains(seStr, "connection reset by peer") { brokenPipe = true } } } if logger != nil { stack := stack(3) httpRequest, _ := httputil.DumpRequest(c.Request, false) headers := strings.Split(string(httpRequest), "\r\n") headersToStr := sanitizeHeaders(headers) // 优化:抽离出对 headers 的处理 if brokenPipe { go logger.Printf("%s\n%s%s", err, headersToStr, reset) // 优化:并行化日志输出 } else if IsDebugging() { go logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s", timeFormat(time.Now()), headersToStr, err, stack, reset) } else { go logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s", timeFormat(time.Now()), err, stack, reset) } } if brokenPipe { c.Error(err.(error)) //nolint: errcheck c.Abort() } else { handle(c, err) } } }() c.Next() } } func defaultHandleRecovery(c *Context, _ any) { c.AbortWithStatus(http.StatusInternalServerError) } // stack returns a nicely formatted stack frame, skipping skip frames. func stack(skip int) []byte { buf := new(bytes.Buffer) var lines [][]byte var lastFile string for i := skip; ; i++ { pc, file, line, ok := runtime.Caller(i) if !ok { break } fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc) if file != lastFile { if cachedLines, found := fileCache[file]; found { // 优化:使用缓存避免重复读取文件 lines = cachedLines } else { data, err := os.ReadFile(file) if err != nil { continue } lines = bytes.Split(data, []byte{'\n'}) fileCache[file] = lines } lastFile = file } fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line)) } return buf.Bytes() } // sanitizeHeaders masks sensitive header information (e.g., Authorization). // 优化:抽离出 header 处理函数,简化主逻辑 func sanitizeHeaders(headers []string) string { for idx, header := range headers { current := strings.Split(header, ":") if current[0] == "Authorization" { headers[idx] = current[0] + ": *" } } return strings.Join(headers, "\r\n") } // source returns a space-trimmed slice of the n'th line. func source(lines [][]byte, n int) []byte { n-- // in stack trace, lines are 1-indexed but our array is 0-indexed if n < 0 || n >= len(lines) { return dunno } return bytes.TrimSpace(lines[n]) } // function returns, if possible, the name of the function containing the PC. func function(pc uintptr) []byte { fn := runtime.FuncForPC(pc) if fn == nil { return dunno } name := []byte(fn.Name()) if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 { name = name[lastSlash+1:] } if period := bytes.Index(name, dot); period >= 0 { name = name[period+1:] } name = bytes.ReplaceAll(name, centerDot, dot) return name } // timeFormat returns a customized time string for logger. func timeFormat(t time.Time) string { return t.Format("2006/01/02 - 15:04:05") }