mirror of https://github.com/go-gorm/gorm.git
182 lines
5.0 KiB
Go
182 lines
5.0 KiB
Go
package logger
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
// Literals and layouts used when rendering bind parameters into logged SQL.
const (
	// nullStr is written in place of nil parameters.
	nullStr = "NULL"

	// tmFmtZero is written (quoted) in place of a zero time.Time.
	tmFmtZero = "0000-00-00 00:00:00"

	// tmFmtWithMS is the time.Format layout for non-zero times; the ".999"
	// suffix keeps up to millisecond precision and drops trailing zeros.
	tmFmtWithMS = "2006-01-02 15:04:05.999"
)
|
|
|
|
// isPrintable reports whether s is valid UTF-8 text consisting only of
// characters accepted by unicode.IsPrint (letters, marks, numbers,
// punctuation, symbols and the ASCII space). Control characters such as
// '\n' or '\t', and any byte that is not part of a valid UTF-8 sequence,
// make it return false. The empty string is considered printable.
func isPrintable(s string) bool {
	for i, r := range s {
		// Ranging over a string yields unicode.ReplacementChar (U+FFFD) for
		// every byte that is not valid UTF-8, and U+FFFD itself is printable.
		// Only accept it when s really contains the encoded U+FFFD; otherwise
		// the input is binary data, not text.
		if r == unicode.ReplacementChar && !strings.HasPrefix(s[i:], "\uFFFD") {
			return false
		}
		if !unicode.IsPrint(r) {
			return false
		}
	}
	return true
}
|
|
|
|
// convertibleTypes are the target types tried, in this order, for values that
// match no other case in ExplainSQL: a value whose type is convertible to one
// of them is converted and rendered like that primitive.
var convertibleTypes = []reflect.Type{
	reflect.TypeOf((*time.Time)(nil)).Elem(),
	reflect.TypeOf(true),
	reflect.TypeOf([]byte(nil)),
}
|
|
|
|
// numericPlaceholderRe matches the normalized positional placeholder form
// "$N$": a dollar sign, one or more decimal digits, then a closing dollar sign.
var (
	numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
)
|
|
|
|
// isNumeric reports whether k is one of Go's built-in signed integer,
// unsigned integer or floating-point kinds. Uintptr and the complex kinds
// are not counted as numeric.
func isNumeric(k reflect.Kind) bool {
	switch k {
	case reflect.Float32, reflect.Float64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	}
	return false
}
|
|
|
|
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
|
var (
|
|
convertParams func(interface{}, int)
|
|
vars = make([]string, len(avars))
|
|
)
|
|
|
|
convertParams = func(v interface{}, idx int) {
|
|
switch v := v.(type) {
|
|
case bool:
|
|
vars[idx] = strconv.FormatBool(v)
|
|
case time.Time:
|
|
if v.IsZero() {
|
|
vars[idx] = escaper + tmFmtZero + escaper
|
|
} else {
|
|
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
|
|
}
|
|
case *time.Time:
|
|
if v != nil {
|
|
if v.IsZero() {
|
|
vars[idx] = escaper + tmFmtZero + escaper
|
|
} else {
|
|
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
|
|
}
|
|
} else {
|
|
vars[idx] = nullStr
|
|
}
|
|
case driver.Valuer:
|
|
reflectValue := reflect.ValueOf(v)
|
|
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
|
r, _ := v.Value()
|
|
convertParams(r, idx)
|
|
} else {
|
|
vars[idx] = nullStr
|
|
}
|
|
case fmt.Stringer:
|
|
reflectValue := reflect.ValueOf(v)
|
|
switch reflectValue.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
|
|
case reflect.Float32, reflect.Float64:
|
|
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
|
|
case reflect.Bool:
|
|
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
|
case reflect.String:
|
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
|
default:
|
|
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
|
} else {
|
|
vars[idx] = nullStr
|
|
}
|
|
}
|
|
case []byte:
|
|
if s := string(v); isPrintable(s) {
|
|
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
|
} else {
|
|
vars[idx] = escaper + "<binary>" + escaper
|
|
}
|
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
|
vars[idx] = utils.ToString(v)
|
|
case float32:
|
|
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
|
case float64:
|
|
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
|
case string:
|
|
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
|
|
default:
|
|
rv := reflect.ValueOf(v)
|
|
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
|
vars[idx] = nullStr
|
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
|
v, _ = valuer.Value()
|
|
convertParams(v, idx)
|
|
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
|
convertParams(reflect.Indirect(rv).Interface(), idx)
|
|
} else if isNumeric(rv.Kind()) {
|
|
if rv.CanInt() || rv.CanUint() {
|
|
vars[idx] = fmt.Sprintf("%d", rv.Interface())
|
|
} else {
|
|
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
|
|
}
|
|
} else {
|
|
for _, t := range convertibleTypes {
|
|
if rv.Type().ConvertibleTo(t) {
|
|
convertParams(rv.Convert(t).Interface(), idx)
|
|
return
|
|
}
|
|
}
|
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
|
}
|
|
}
|
|
}
|
|
|
|
for idx, v := range avars {
|
|
convertParams(v, idx)
|
|
}
|
|
|
|
if numericPlaceholder == nil {
|
|
var idx int
|
|
var newSQL strings.Builder
|
|
|
|
for _, v := range []byte(sql) {
|
|
if v == '?' {
|
|
if len(vars) > idx {
|
|
newSQL.WriteString(vars[idx])
|
|
idx++
|
|
continue
|
|
}
|
|
}
|
|
newSQL.WriteByte(v)
|
|
}
|
|
|
|
sql = newSQL.String()
|
|
} else {
|
|
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
|
|
|
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
|
|
num := v[1 : len(v)-1]
|
|
n, _ := strconv.Atoi(num)
|
|
|
|
// position var start from 1 ($1, $2)
|
|
n -= 1
|
|
if n >= 0 && n <= len(vars)-1 {
|
|
return vars[n]
|
|
}
|
|
return v
|
|
})
|
|
}
|
|
|
|
return sql
|
|
}
|