From c3b798aec869da7b8c513c45e275a4310dfede31 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 13:22:08 +0800 Subject: [PATCH] Refactor SQL Explainer --- logger/sql.go | 33 ++++++++++++++++++++++++--------- logger/sql_test.go | 32 ++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index b0e11027..f63dc160 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,6 +3,7 @@ package logger import ( "database/sql/driver" "fmt" + "reflect" "regexp" "strconv" "strings" @@ -19,19 +20,17 @@ func isPrintable(s []byte) bool { return true } -func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { - for idx, v := range vars { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } +var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { + var convertParams func(interface{}, int) + + convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: vars[idx] = fmt.Sprint(v) case time.Time: vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper - case *time.Time: - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper @@ -48,19 +47,35 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v == nil { vars[idx] = "NULL" } else { + rv := reflect.Indirect(reflect.ValueOf(v)) + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } } } + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + convertParams(v, idx) + } + if numericPlaceholder == nil { for _, v := range vars { sql = strings.Replace(sql, "?", v.(string), 1) } } else { - sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1") + sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1) } } diff --git a/logger/sql_test.go b/logger/sql_test.go index d98e19b3..829d6302 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -9,7 +9,13 @@ import ( ) func TestExplainSQL(t *testing.T) { - tt := now.MustParse("2020-02-23 11:10:10") + type role string + type password []byte + var ( + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password([]byte("pass")) + ) results := []struct { SQL string @@ -18,22 +24,28 @@ func TestExplainSQL(t *testing.T) { Result string }{ { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, - Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile("@p(\\d+)"), - Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($2, $3, $0, $1, $6, $7, $4, $5, $8, $9, $10)", NumericRegexp: regexp.MustCompile("\\$(\\d+)"), - Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p10, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9)", + NumericRegexp: regexp.MustCompile("@p(\\d+)"), + Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, }