Explain SQL for dialects

This commit is contained in:
Jinzhu 2020-02-23 12:39:26 +08:00
parent fe24c3f105
commit bc5ceff82f
13 changed files with 182 additions and 9 deletions

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor {
} }
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
curTime := time.Now()
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) {
for _, f := range p.fns { for _, f := range p.fns {
f(db) f(db)
} }
if stmt := db.Statement; stmt != nil {
db.Logger.RunWith(logger.Info, func() {
db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars))
})
}
} }
func (p *processor) Get(name string) func(*DB) { func (p *processor) Get(name string) func(*DB) {

View File

@ -3,11 +3,13 @@ package mssql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"regexp"
"strconv" "strconv"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
) )
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // `name` return [2]byte{'"', '"'} // `name`
} }
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType { switch field.DataType {
case schema.Bool: case schema.Bool:

View File

@ -2,6 +2,7 @@ package mssql_test
import ( import (
"fmt" "fmt"
"os"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -15,7 +16,12 @@ var (
) )
func init() { func init() {
if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN")
}
if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err)) panic(fmt.Sprintf("failed to initialize database, got error %v", err))
} }
} }

View File

@ -8,6 +8,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
) )
@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name` return [2]byte{'`', '`'} // `name`
} }
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType { switch field.DataType {
case schema.Bool: case schema.Bool:

View File

@ -2,6 +2,7 @@ package mysql_test
import ( import (
"fmt" "fmt"
"os"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -15,7 +16,12 @@ var (
) )
func init() { func init() {
if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN")
}
if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err)) panic(fmt.Sprintf("failed to initialize database, got error %v", err))
} }
} }

View File

@ -3,10 +3,12 @@ package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"regexp"
"strconv" "strconv"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // "name" return [2]byte{'"', '"'} // "name"
} }
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType { switch field.DataType {
case schema.Bool: case schema.Bool:

View File

@ -2,6 +2,7 @@ package postgres_test
import ( import (
"fmt" "fmt"
"os"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -15,7 +16,12 @@ var (
) )
func init() { func init() {
if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN")
}
if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err)) panic(fmt.Sprintf("failed to initialize database, got error %v", err))
} }
} }

View File

@ -5,6 +5,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name` return [2]byte{'`', '`'} // `name`
} }
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType { switch field.DataType {
case schema.Bool: case schema.Bool:

View File

@ -14,6 +14,7 @@ type Dialector interface {
DataTypeOf(*schema.Field) string DataTypeOf(*schema.Field) string
BindVar(stmt *Statement, v interface{}) string BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte QuoteChars() [2]byte
Explain(sql string, vars ...interface{}) string
} }
// CommonDB common db interface // CommonDB common db interface

View File

@ -11,9 +11,9 @@ type LogLevel int
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
const ( const (
Info LogLevel = iota + 1 Error LogLevel = iota + 1
Warn Warn
Error Info
) )
// Interface logger interface // Interface logger interface
@ -22,6 +22,7 @@ type Interface interface {
Info(string, ...interface{}) Info(string, ...interface{})
Warn(string, ...interface{}) Warn(string, ...interface{})
Error(string, ...interface{}) Error(string, ...interface{})
RunWith(LogLevel, func())
} }
// Writer log writer interface // Writer log writer interface
@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface {
// Info print info // Info print info
func (logger Logger) Info(msg string, data ...interface{}) { func (logger Logger) Info(msg string, data ...interface{}) {
if logger.logLevel <= Info { if logger.logLevel >= Info {
logger.Print("[info] " + fmt.Sprintf(msg, data...)) logger.Print("[info] " + fmt.Sprintf(msg, data...))
} }
} }
// Warn print warn messages // Warn print warn messages
func (logger Logger) Warn(msg string, data ...interface{}) { func (logger Logger) Warn(msg string, data ...interface{}) {
if logger.logLevel <= Warn { if logger.logLevel >= Warn {
logger.Print("[warn] " + fmt.Sprintf(msg, data...)) logger.Print("[warn] " + fmt.Sprintf(msg, data...))
} }
} }
// Error print error messages // Error print error messages
func (logger Logger) Error(msg string, data ...interface{}) { func (logger Logger) Error(msg string, data ...interface{}) {
if logger.logLevel <= Error { if logger.logLevel >= Error {
logger.Print("[error] " + fmt.Sprintf(msg, data...)) logger.Print("[error] " + fmt.Sprintf(msg, data...))
} }
} }
func (logger Logger) RunWith(logLevel LogLevel, fc func()) {
if logger.logLevel >= logLevel {
fc()
}
}

68
logger/sql.go Normal file
View File

@ -0,0 +1,68 @@
package logger
import (
"database/sql/driver"
"fmt"
"regexp"
"strconv"
"strings"
"time"
"unicode"
)
func isPrintable(s []byte) bool {
for _, r := range s {
if !unicode.IsPrint(rune(r)) {
return false
}
}
return true
}
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
for idx, v := range vars {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
switch v := v.(type) {
case bool:
vars[idx] = fmt.Sprint(v)
case time.Time:
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
case *time.Time:
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
case []byte:
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = fmt.Sprintf("%d", v)
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case string:
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
default:
if v == nil {
vars[idx] = "NULL"
} else {
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
}
}
}
if numericPlaceholder == nil {
for _, v := range vars {
sql = strings.Replace(sql, "?", v.(string), 1)
}
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1")
for idx, v := range vars {
sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1)
}
}
return sql
}

45
logger/sql_test.go Normal file
View File

@ -0,0 +1,45 @@
package logger_test
import (
"regexp"
"testing"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/now"
)
func TestExplainSQL(t *testing.T) {
tt := now.MustParse("2020-02-23 11:10:10")
results := []struct {
SQL string
NumericRegexp *regexp.Regexp
Vars []interface{}
Result string
}{
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)",
NumericRegexp: regexp.MustCompile("@p(\\d+)"),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)",
NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
},
}
for idx, r := range results {
if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result {
t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result)
}
}
}

View File

@ -1,4 +1,4 @@
dialects=("postgres" "mysql" "mssql" "sqlite") dialects=("sqlite" "mysql" "postgres" "mssql")
if [[ $(pwd) == *"gorm/tests"* ]]; then if [[ $(pwd) == *"gorm/tests"* ]]; then
cd .. cd ..