Better support Stringer when explain SQL

This commit is contained in:
Jinzhu 2022-02-09 15:17:19 +08:00
parent d22215129e
commit 4eeb839cea
4 changed files with 86 additions and 7 deletions

View File

@ -12,6 +12,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found") var ErrRecordNotFound = errors.New("record not found")
// Colors // Colors
@ -30,13 +31,17 @@ const (
YellowBold = "\033[33;1m" YellowBold = "\033[33;1m"
) )
// LogLevel // LogLevel log level
type LogLevel int type LogLevel int
const ( const (
// Silent silent log level
Silent LogLevel = iota + 1 Silent LogLevel = iota + 1
// Error error log level
Error Error
// Warn warn log level
Warn Warn
// Info info log level
Info Info
) )
@ -45,6 +50,7 @@ type Writer interface {
Printf(string, ...interface{}) Printf(string, ...interface{})
} }
// Config logger config
type Config struct { type Config struct {
SlowThreshold time.Duration SlowThreshold time.Duration
Colorful bool Colorful bool
@ -62,16 +68,20 @@ type Interface interface {
} }
var ( var (
// Discard Discard logger will print any log to ioutil.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond, SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn, LogLevel: Warn,
IgnoreRecordNotFoundError: false, IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
) )
// New initialize logger
func New(writer Writer, config Config) Interface { func New(writer Writer, config Config) Interface {
var ( var (
infoStr = "%s\n[info] " infoStr = "%s\n[info] "
@ -179,10 +189,12 @@ type traceRecorder struct {
Err error Err error
} }
// New new trace recorder
func (l traceRecorder) New() *traceRecorder { func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
} }
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin l.BeginAt = begin
l.SQL, l.RowsAffected = fc() l.SQL, l.RowsAffected = fc()

View File

@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int) var (
vars := make([]string, len(avars)) convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) { convertParams = func(v interface{}, idx int) {
switch v := v.(type) { switch v := v.(type) {
@ -64,11 +67,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
} }
case fmt.Stringer: case fmt.Stringer:
reflectValue := reflect.ValueOf(v) reflectValue := reflect.ValueOf(v)
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
} else { } else {
vars[idx] = nullStr vars[idx] = nullStr
} }
}
case []byte: case []byte:
if isPrintable(v) { if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper

View File

@ -9,7 +9,7 @@ require (
github.com/jinzhu/now v1.1.4 github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4 github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect github.com/mattn/go-sqlite3 v1.14.11 // indirect
golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect
gorm.io/driver/mysql v1.2.3 gorm.io/driver/mysql v1.2.3
gorm.io/driver/postgres v1.2.3 gorm.io/driver/postgres v1.2.3
gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlite v1.2.6

View File

@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
} }
} }
type ageInt int8
func (ageInt) String() string {
return "age"
}
type ageBool bool
func (ageBool) String() string {
return "age"
}
type ageUint64 uint64
func (ageUint64) String() string {
return "age"
}
type ageFloat float64
func (ageFloat) String() string {
return "age"
}
func TestExplainSQL(t *testing.T) {
user := *GetUser("explain-sql", Config{})
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
}
func TestGroupConditions(t *testing.T) { func TestGroupConditions(t *testing.T) {
type Pizza struct { type Pizza struct {
ID uint ID uint