forked from mirror/gorm
Better support Stringer when explain SQL
This commit is contained in:
parent
d22215129e
commit
4eeb839cea
|
@ -12,6 +12,7 @@ import (
|
|||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// ErrRecordNotFound record not found error
|
||||
var ErrRecordNotFound = errors.New("record not found")
|
||||
|
||||
// Colors
|
||||
|
@ -30,13 +31,17 @@ const (
|
|||
YellowBold = "\033[33;1m"
|
||||
)
|
||||
|
||||
// LogLevel
|
||||
// LogLevel log level
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
// Silent silent log level
|
||||
Silent LogLevel = iota + 1
|
||||
// Error error log level
|
||||
Error
|
||||
// Warn warn log level
|
||||
Warn
|
||||
// Info info log level
|
||||
Info
|
||||
)
|
||||
|
||||
|
@ -45,6 +50,7 @@ type Writer interface {
|
|||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
// Config logger config
|
||||
type Config struct {
|
||||
SlowThreshold time.Duration
|
||||
Colorful bool
|
||||
|
@ -62,16 +68,20 @@ type Interface interface {
|
|||
}
|
||||
|
||||
var (
|
||||
// Discard Discard logger will print any log to ioutil.Discard
|
||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||
// Default Default logger
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: Warn,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
Colorful: true,
|
||||
})
|
||||
// Recorder Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
)
|
||||
|
||||
// New initialize logger
|
||||
func New(writer Writer, config Config) Interface {
|
||||
var (
|
||||
infoStr = "%s\n[info] "
|
||||
|
@ -179,10 +189,12 @@ type traceRecorder struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
// New new trace recorder
|
||||
func (l traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
// Trace implement logger interface
|
||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
l.BeginAt = begin
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
|
|
|
@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
|
|||
|
||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var convertParams func(interface{}, int)
|
||||
vars := make([]string, len(avars))
|
||||
var (
|
||||
convertParams func(interface{}, int)
|
||||
vars = make([]string, len(avars))
|
||||
)
|
||||
|
||||
convertParams = func(v interface{}, idx int) {
|
||||
switch v := v.(type) {
|
||||
|
@ -64,11 +67,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||
}
|
||||
case fmt.Stringer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if isPrintable(v) {
|
||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
|
|
|
@ -9,7 +9,7 @@ require (
|
|||
github.com/jinzhu/now v1.1.4
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect
|
||||
golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect
|
||||
gorm.io/driver/mysql v1.2.3
|
||||
gorm.io/driver/postgres v1.2.3
|
||||
gorm.io/driver/sqlite v1.2.6
|
||||
|
|
|
@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type ageInt int8
|
||||
|
||||
func (ageInt) String() string {
|
||||
return "age"
|
||||
}
|
||||
|
||||
type ageBool bool
|
||||
|
||||
func (ageBool) String() string {
|
||||
return "age"
|
||||
}
|
||||
|
||||
type ageUint64 uint64
|
||||
|
||||
func (ageUint64) String() string {
|
||||
return "age"
|
||||
}
|
||||
|
||||
type ageFloat float64
|
||||
|
||||
func (ageFloat) String() string {
|
||||
return "age"
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
user := *GetUser("explain-sql", Config{})
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
|
||||
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
|
||||
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
|
||||
t.Errorf("Failed to generate sql, got %v", sql)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
|
||||
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
|
||||
t.Errorf("Failed to generate sql, got %v", sql)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
|
||||
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
|
||||
t.Errorf("Failed to generate sql, got %v", sql)
|
||||
}
|
||||
|
||||
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
|
||||
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
|
||||
t.Errorf("Failed to generate sql, got %v", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupConditions(t *testing.T) {
|
||||
type Pizza struct {
|
||||
ID uint
|
||||
|
|
Loading…
Reference in New Issue