forked from mirror/gorm
Better support Stringer when explain SQL
This commit is contained in:
parent
d22215129e
commit
4eeb839cea
|
@ -12,6 +12,7 @@ import (
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrRecordNotFound record not found error
|
||||||
var ErrRecordNotFound = errors.New("record not found")
|
var ErrRecordNotFound = errors.New("record not found")
|
||||||
|
|
||||||
// Colors
|
// Colors
|
||||||
|
@ -30,13 +31,17 @@ const (
|
||||||
YellowBold = "\033[33;1m"
|
YellowBold = "\033[33;1m"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogLevel
|
// LogLevel log level
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Silent silent log level
|
||||||
Silent LogLevel = iota + 1
|
Silent LogLevel = iota + 1
|
||||||
|
// Error error log level
|
||||||
Error
|
Error
|
||||||
|
// Warn warn log level
|
||||||
Warn
|
Warn
|
||||||
|
// Info info log level
|
||||||
Info
|
Info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,6 +50,7 @@ type Writer interface {
|
||||||
Printf(string, ...interface{})
|
Printf(string, ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Config logger config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
SlowThreshold time.Duration
|
SlowThreshold time.Duration
|
||||||
Colorful bool
|
Colorful bool
|
||||||
|
@ -62,16 +68,20 @@ type Interface interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
// Discard Discard logger will print any log to ioutil.Discard
|
||||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||||
|
// Default Default logger
|
||||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||||
SlowThreshold: 200 * time.Millisecond,
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
LogLevel: Warn,
|
LogLevel: Warn,
|
||||||
IgnoreRecordNotFoundError: false,
|
IgnoreRecordNotFoundError: false,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
})
|
})
|
||||||
|
// Recorder Recorder logger records running SQL into a recorder instance
|
||||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// New initialize logger
|
||||||
func New(writer Writer, config Config) Interface {
|
func New(writer Writer, config Config) Interface {
|
||||||
var (
|
var (
|
||||||
infoStr = "%s\n[info] "
|
infoStr = "%s\n[info] "
|
||||||
|
@ -179,10 +189,12 @@ type traceRecorder struct {
|
||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New new trace recorder
|
||||||
func (l traceRecorder) New() *traceRecorder {
|
func (l traceRecorder) New() *traceRecorder {
|
||||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace implement logger interface
|
||||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
l.BeginAt = begin
|
l.BeginAt = begin
|
||||||
l.SQL, l.RowsAffected = fc()
|
l.SQL, l.RowsAffected = fc()
|
||||||
|
|
|
@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
|
||||||
|
|
||||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
|
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||||
var convertParams func(interface{}, int)
|
var (
|
||||||
vars := make([]string, len(avars))
|
convertParams func(interface{}, int)
|
||||||
|
vars = make([]string, len(avars))
|
||||||
|
)
|
||||||
|
|
||||||
convertParams = func(v interface{}, idx int) {
|
convertParams = func(v interface{}, idx int) {
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
|
@ -64,11 +67,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||||
}
|
}
|
||||||
case fmt.Stringer:
|
case fmt.Stringer:
|
||||||
reflectValue := reflect.ValueOf(v)
|
reflectValue := reflect.ValueOf(v)
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
|
||||||
|
case reflect.Bool:
|
||||||
|
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||||
|
case reflect.String:
|
||||||
|
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
default:
|
||||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||||
} else {
|
} else {
|
||||||
vars[idx] = nullStr
|
vars[idx] = nullStr
|
||||||
}
|
}
|
||||||
|
}
|
||||||
case []byte:
|
case []byte:
|
||||||
if isPrintable(v) {
|
if isPrintable(v) {
|
||||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
|
|
@ -9,7 +9,7 @@ require (
|
||||||
github.com/jinzhu/now v1.1.4
|
github.com/jinzhu/now v1.1.4
|
||||||
github.com/lib/pq v1.10.4
|
github.com/lib/pq v1.10.4
|
||||||
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
||||||
golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect
|
golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect
|
||||||
gorm.io/driver/mysql v1.2.3
|
gorm.io/driver/mysql v1.2.3
|
||||||
gorm.io/driver/postgres v1.2.3
|
gorm.io/driver/postgres v1.2.3
|
||||||
gorm.io/driver/sqlite v1.2.6
|
gorm.io/driver/sqlite v1.2.6
|
||||||
|
|
|
@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ageInt int8
|
||||||
|
|
||||||
|
func (ageInt) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageBool bool
|
||||||
|
|
||||||
|
func (ageBool) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageUint64 uint64
|
||||||
|
|
||||||
|
func (ageUint64) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageFloat float64
|
||||||
|
|
||||||
|
func (ageFloat) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExplainSQL(t *testing.T) {
|
||||||
|
user := *GetUser("explain-sql", Config{})
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
|
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
|
||||||
|
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGroupConditions(t *testing.T) {
|
func TestGroupConditions(t *testing.T) {
|
||||||
type Pizza struct {
|
type Pizza struct {
|
||||||
ID uint
|
ID uint
|
||||||
|
|
Loading…
Reference in New Issue