From 4eeb839ceabb983b634f9cf9fffa1dd773b6803d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 15:17:19 +0800 Subject: [PATCH] Better support Stringer when explain SQL --- logger/logger.go | 14 ++++++++++- logger/sql.go | 24 ++++++++++++++---- tests/go.mod | 2 +- tests/sql_builder_test.go | 53 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 7 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 0c4ca4a0..2ffd28d5 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/utils" ) +// ErrRecordNotFound record not found error var ErrRecordNotFound = errors.New("record not found") // Colors @@ -30,13 +31,17 @@ const ( YellowBold = "\033[33;1m" ) -// LogLevel +// LogLevel log level type LogLevel int const ( + // Silent silent log level Silent LogLevel = iota + 1 + // Error error log level Error + // Warn warn log level Warn + // Info info log level Info ) @@ -45,6 +50,7 @@ type Writer interface { Printf(string, ...interface{}) } +// Config logger config type Config struct { SlowThreshold time.Duration Colorful bool @@ -62,16 +68,20 @@ type Interface interface { } var ( + // Discard Discard logger will print any log to ioutil.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, Colorful: true, }) + // Recorder Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +// New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " @@ -179,10 +189,12 @@ type traceRecorder struct { Err error } +// New new trace recorder func (l traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } +// Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() diff --git a/logger/sql.go b/logger/sql.go index 5ecb0ae2..e0be57c0 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,9 +30,12 @@ func isPrintable(s []byte) bool { var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { - var convertParams func(interface{}, int) - vars := make([]string, len(avars)) + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { @@ -64,10 +67,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case fmt.Stringer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = nullStr + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = nullStr + } } case []byte: if isPrintable(v) { diff --git a/tests/go.mod b/tests/go.mod index 5488c17e..3453f77b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect + golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 237d807b..897f687f 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { } } +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint