From bc5ceff82ff17b72081cc40bb7711489312349c4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 12:39:26 +0800 Subject: [PATCH] Explain SQL for dialects --- callbacks.go | 8 ++++ dialects/mssql/mssql.go | 8 ++++ dialects/mssql/mssql_test.go | 8 +++- dialects/mysql/mysql.go | 5 +++ dialects/mysql/mysql_test.go | 8 +++- dialects/postgres/postgres.go | 8 ++++ dialects/postgres/postgres_test.go | 8 +++- dialects/sqlite/sqlite.go | 5 +++ interfaces.go | 1 + logger/logger.go | 17 +++++--- logger/sql.go | 68 ++++++++++++++++++++++++++++++ logger/sql_test.go | 45 ++++++++++++++++++++ tests/tests_all.sh | 2 +- 13 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 logger/sql.go create mode 100644 logger/sql_test.go diff --git a/callbacks.go b/callbacks.go index 4f19a681..41951168 100644 --- a/callbacks.go +++ b/callbacks.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "time" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" @@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { + curTime := time.Now() if stmt := db.Statement; stmt != nil { if stmt.Model == nil { stmt.Model = stmt.Dest @@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) { for _, f := range p.fns { f(db) } + + if stmt := db.Statement; stmt != nil { + db.Logger.RunWith(logger.Info, func() { + db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars)) + }) + } } func (p *processor) Get(name string) func(*DB) { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 79e36385..b93cc8f6 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -3,11 +3,13 @@ package mssql import ( "database/sql" "fmt" + "regexp" "strconv" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" ) @@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // `name` } +var numericPlaceholder = regexp.MustCompile("@p(\\d+)") + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/mssql/mssql_test.go b/dialects/mssql/mssql_test.go index b56e7369..49b3cd6a 100644 --- a/dialects/mssql/mssql_test.go +++ b/dialects/mssql/mssql_test.go @@ -2,6 +2,7 @@ package mssql_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { + dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 629b89df..e1bf985a 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -8,6 +8,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" ) @@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index f079ad60..5bc1debd 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -2,6 +2,7 @@ package mysql_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index bb9726a8..3ee4ba9f 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -3,10 +3,12 @@ package postgres import ( "database/sql" "fmt" + "regexp" "strconv" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" _ "github.com/lib/pq" @@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // "name" } +var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go index 84c0fe53..a1252d92 100644 --- a/dialects/postgres/postgres_test.go +++ b/dialects/postgres/postgres_test.go @@ -2,6 +2,7 @@ package postgres_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { + dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 54fa7de0..a6aba066 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" _ "github.com/mattn/go-sqlite3" @@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/interfaces.go b/interfaces.go index 8f0f3085..bf1aab46 100644 --- a/interfaces.go +++ b/interfaces.go @@ -14,6 +14,7 @@ type Dialector interface { DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string QuoteChars() [2]byte + Explain(sql string, vars ...interface{}) string } // CommonDB common db interface diff --git a/logger/logger.go b/logger/logger.go index cad9be16..049b724d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -11,9 +11,9 @@ type LogLevel int var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} const ( - Info LogLevel = iota + 1 + Error LogLevel = iota + 1 Warn - Error + Info ) // Interface logger interface @@ -22,6 +22,7 @@ type Interface interface { Info(string, ...interface{}) Warn(string, ...interface{}) Error(string, ...interface{}) + RunWith(LogLevel, func()) } // Writer log writer interface @@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface { // Info print info func (logger Logger) Info(msg string, data ...interface{}) { - if logger.logLevel <= Info { + if logger.logLevel >= Info { logger.Print("[info] " + fmt.Sprintf(msg, data...)) } } // Warn print warn messages func (logger Logger) Warn(msg string, data ...interface{}) { - if logger.logLevel <= Warn { + if logger.logLevel >= Warn { logger.Print("[warn] " + fmt.Sprintf(msg, data...)) } } // Error print error messages func (logger Logger) Error(msg string, data ...interface{}) { - if logger.logLevel <= Error { + if logger.logLevel >= Error { logger.Print("[error] " + fmt.Sprintf(msg, data...)) } } + +func (logger Logger) RunWith(logLevel LogLevel, fc func()) { + if logger.logLevel >= logLevel { + fc() + } +} diff --git a/logger/sql.go b/logger/sql.go new file mode 100644 index 00000000..b0e11027 --- /dev/null +++ b/logger/sql.go @@ -0,0 +1,68 @@ +package logger + +import ( + "database/sql/driver" + "fmt" + "regexp" + "strconv" + "strings" + "time" + "unicode" +) + +func isPrintable(s []byte) bool { + for _, r := range s { + if !unicode.IsPrint(rune(r)) { + return false + } + } + return true +} + +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + switch v := v.(type) { + case bool: + vars[idx] = fmt.Sprint(v) + case time.Time: + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + case *time.Time: + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + case []byte: + if isPrintable(v) { + vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = fmt.Sprintf("%d", v) + case float64, float32: + vars[idx] = fmt.Sprintf("%.6f", v) + case string: + vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + default: + if v == nil { + vars[idx] = "NULL" + } else { + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } + } + } + + if numericPlaceholder == nil { + for _, v := range vars { + sql = strings.Replace(sql, "?", v.(string), 1) + } + } else { + sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1") + for idx, v := range vars { + sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1) + } + } + + return sql +} diff --git a/logger/sql_test.go b/logger/sql_test.go new file mode 100644 index 00000000..d98e19b3 --- /dev/null +++ b/logger/sql_test.go @@ -0,0 +1,45 @@ +package logger_test + +import ( + "regexp" + "testing" + + "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/now" +) + +func TestExplainSQL(t *testing.T) { + tt := now.MustParse("2020-02-23 11:10:10") + + results := []struct { + SQL string + NumericRegexp *regexp.Regexp + Vars []interface{} + Result string + }{ + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)", + NumericRegexp: regexp.MustCompile("@p(\\d+)"), + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)", + NumericRegexp: regexp.MustCompile("\\$(\\d+)"), + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + } + + for idx, r := range results { + if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { + t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) + } + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 91d415f1..cd42e1e0 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") +dialects=("sqlite" "mysql" "postgres" "mssql") if [[ $(pwd) == *"gorm/tests"* ]]; then cd ..