forked from mirror/gorm
Explain SQL for dialects
This commit is contained in:
parent
fe24c3f105
commit
bc5ceff82f
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
|
@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor {
|
|||
}
|
||||
|
||||
func (p *processor) Execute(db *DB) {
|
||||
curTime := time.Now()
|
||||
if stmt := db.Statement; stmt != nil {
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
|
@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) {
|
|||
for _, f := range p.fns {
|
||||
f(db)
|
||||
}
|
||||
|
||||
if stmt := db.Statement; stmt != nil {
|
||||
db.Logger.RunWith(logger.Info, func() {
|
||||
db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
|
|
|
@ -3,11 +3,13 @@ package mssql
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/migrator"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
|||
return [2]byte{'"', '"'} // `name`
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
|
|
|
@ -2,6 +2,7 @@ package mssql_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -15,7 +16,12 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil {
|
||||
dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||
if os.Getenv("GORM_DSN") != "" {
|
||||
dsn = os.Getenv("GORM_DSN")
|
||||
}
|
||||
|
||||
if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/migrator"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
|||
return [2]byte{'`', '`'} // `name`
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
|
|
|
@ -2,6 +2,7 @@ package mysql_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -15,7 +16,12 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil {
|
||||
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
|
||||
if os.Getenv("GORM_DSN") != "" {
|
||||
dsn = os.Getenv("GORM_DSN")
|
||||
}
|
||||
|
||||
if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,10 +3,12 @@ package postgres
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/migrator"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
|||
return [2]byte{'"', '"'} // "name"
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
|
|
|
@ -2,6 +2,7 @@ package postgres_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -15,7 +16,12 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil {
|
||||
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||
if os.Getenv("GORM_DSN") != "" {
|
||||
dsn = os.Getenv("GORM_DSN")
|
||||
}
|
||||
|
||||
if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/migrator"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
|||
return [2]byte{'`', '`'} // `name`
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
|
|
|
@ -14,6 +14,7 @@ type Dialector interface {
|
|||
DataTypeOf(*schema.Field) string
|
||||
BindVar(stmt *Statement, v interface{}) string
|
||||
QuoteChars() [2]byte
|
||||
Explain(sql string, vars ...interface{}) string
|
||||
}
|
||||
|
||||
// CommonDB common db interface
|
||||
|
|
|
@ -11,9 +11,9 @@ type LogLevel int
|
|||
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
|
||||
|
||||
const (
|
||||
Info LogLevel = iota + 1
|
||||
Error LogLevel = iota + 1
|
||||
Warn
|
||||
Error
|
||||
Info
|
||||
)
|
||||
|
||||
// Interface logger interface
|
||||
|
@ -22,6 +22,7 @@ type Interface interface {
|
|||
Info(string, ...interface{})
|
||||
Warn(string, ...interface{})
|
||||
Error(string, ...interface{})
|
||||
RunWith(LogLevel, func())
|
||||
}
|
||||
|
||||
// Writer log writer interface
|
||||
|
@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface {
|
|||
|
||||
// Info print info
|
||||
func (logger Logger) Info(msg string, data ...interface{}) {
|
||||
if logger.logLevel <= Info {
|
||||
if logger.logLevel >= Info {
|
||||
logger.Print("[info] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (logger Logger) Warn(msg string, data ...interface{}) {
|
||||
if logger.logLevel <= Warn {
|
||||
if logger.logLevel >= Warn {
|
||||
logger.Print("[warn] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (logger Logger) Error(msg string, data ...interface{}) {
|
||||
if logger.logLevel <= Error {
|
||||
if logger.logLevel >= Error {
|
||||
logger.Print("[error] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
||||
func (logger Logger) RunWith(logLevel LogLevel, fc func()) {
|
||||
if logger.logLevel >= logLevel {
|
||||
fc()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func isPrintable(s []byte) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsPrint(rune(r)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
|
||||
for idx, v := range vars {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
vars[idx] = fmt.Sprint(v)
|
||||
case time.Time:
|
||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||
case *time.Time:
|
||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||
case []byte:
|
||||
if isPrintable(v) {
|
||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
vars[idx] = fmt.Sprintf("%d", v)
|
||||
case float64, float32:
|
||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
||||
default:
|
||||
if v == nil {
|
||||
vars[idx] = "NULL"
|
||||
} else {
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if numericPlaceholder == nil {
|
||||
for _, v := range vars {
|
||||
sql = strings.Replace(sql, "?", v.(string), 1)
|
||||
}
|
||||
} else {
|
||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1")
|
||||
for idx, v := range vars {
|
||||
sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1)
|
||||
}
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package logger_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/now"
|
||||
)
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
tt := now.MustParse("2020-02-23 11:10:10")
|
||||
|
||||
results := []struct {
|
||||
SQL string
|
||||
NumericRegexp *regexp.Regexp
|
||||
Vars []interface{}
|
||||
Result string
|
||||
}{
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)",
|
||||
NumericRegexp: regexp.MustCompile("@p(\\d+)"),
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)",
|
||||
NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
|
||||
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, r := range results {
|
||||
if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result {
|
||||
t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
dialects=("postgres" "mysql" "mssql" "sqlite")
|
||||
dialects=("sqlite" "mysql" "postgres" "mssql")
|
||||
|
||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||
cd ..
|
||||
|
|
Loading…
Reference in New Issue