forked from mirror/gorm
Explain SQL for dialects
This commit is contained in:
parent
fe24c3f105
commit
bc5ceff82f
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/logger"
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Execute(db *DB) {
|
func (p *processor) Execute(db *DB) {
|
||||||
|
curTime := time.Now()
|
||||||
if stmt := db.Statement; stmt != nil {
|
if stmt := db.Statement; stmt != nil {
|
||||||
if stmt.Model == nil {
|
if stmt.Model == nil {
|
||||||
stmt.Model = stmt.Dest
|
stmt.Model = stmt.Dest
|
||||||
|
@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) {
|
||||||
for _, f := range p.fns {
|
for _, f := range p.fns {
|
||||||
f(db)
|
f(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stmt := db.Statement; stmt != nil {
|
||||||
|
db.Logger.RunWith(logger.Info, func() {
|
||||||
|
db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Get(name string) func(*DB) {
|
func (p *processor) Get(name string) func(*DB) {
|
||||||
|
|
|
@ -3,11 +3,13 @@ package mssql
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/migrator"
|
"github.com/jinzhu/gorm/migrator"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
||||||
return [2]byte{'"', '"'} // `name`
|
return [2]byte{'"', '"'} // `name`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
||||||
|
|
||||||
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||||
|
}
|
||||||
|
|
||||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
switch field.DataType {
|
switch field.DataType {
|
||||||
case schema.Bool:
|
case schema.Bool:
|
||||||
|
|
|
@ -2,6 +2,7 @@ package mssql_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -15,7 +16,12 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil {
|
dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||||
|
if os.Getenv("GORM_DSN") != "" {
|
||||||
|
dsn = os.Getenv("GORM_DSN")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil {
|
||||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/migrator"
|
"github.com/jinzhu/gorm/migrator"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
||||||
return [2]byte{'`', '`'} // `name`
|
return [2]byte{'`', '`'} // `name`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||||
|
}
|
||||||
|
|
||||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
switch field.DataType {
|
switch field.DataType {
|
||||||
case schema.Bool:
|
case schema.Bool:
|
||||||
|
|
|
@ -2,6 +2,7 @@ package mysql_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -15,7 +16,12 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil {
|
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
|
||||||
|
if os.Getenv("GORM_DSN") != "" {
|
||||||
|
dsn = os.Getenv("GORM_DSN")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
|
||||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,12 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/migrator"
|
"github.com/jinzhu/gorm/migrator"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
||||||
return [2]byte{'"', '"'} // "name"
|
return [2]byte{'"', '"'} // "name"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
|
||||||
|
|
||||||
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||||
|
}
|
||||||
|
|
||||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
switch field.DataType {
|
switch field.DataType {
|
||||||
case schema.Bool:
|
case schema.Bool:
|
||||||
|
|
|
@ -2,6 +2,7 @@ package postgres_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -15,7 +16,12 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil {
|
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||||
|
if os.Getenv("GORM_DSN") != "" {
|
||||||
|
dsn = os.Getenv("GORM_DSN")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil {
|
||||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/migrator"
|
"github.com/jinzhu/gorm/migrator"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
||||||
return [2]byte{'`', '`'} // `name`
|
return [2]byte{'`', '`'} // `name`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||||
|
}
|
||||||
|
|
||||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
switch field.DataType {
|
switch field.DataType {
|
||||||
case schema.Bool:
|
case schema.Bool:
|
||||||
|
|
|
@ -14,6 +14,7 @@ type Dialector interface {
|
||||||
DataTypeOf(*schema.Field) string
|
DataTypeOf(*schema.Field) string
|
||||||
BindVar(stmt *Statement, v interface{}) string
|
BindVar(stmt *Statement, v interface{}) string
|
||||||
QuoteChars() [2]byte
|
QuoteChars() [2]byte
|
||||||
|
Explain(sql string, vars ...interface{}) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommonDB common db interface
|
// CommonDB common db interface
|
||||||
|
|
|
@ -11,9 +11,9 @@ type LogLevel int
|
||||||
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
|
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Info LogLevel = iota + 1
|
Error LogLevel = iota + 1
|
||||||
Warn
|
Warn
|
||||||
Error
|
Info
|
||||||
)
|
)
|
||||||
|
|
||||||
// Interface logger interface
|
// Interface logger interface
|
||||||
|
@ -22,6 +22,7 @@ type Interface interface {
|
||||||
Info(string, ...interface{})
|
Info(string, ...interface{})
|
||||||
Warn(string, ...interface{})
|
Warn(string, ...interface{})
|
||||||
Error(string, ...interface{})
|
Error(string, ...interface{})
|
||||||
|
RunWith(LogLevel, func())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Writer log writer interface
|
// Writer log writer interface
|
||||||
|
@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface {
|
||||||
|
|
||||||
// Info print info
|
// Info print info
|
||||||
func (logger Logger) Info(msg string, data ...interface{}) {
|
func (logger Logger) Info(msg string, data ...interface{}) {
|
||||||
if logger.logLevel <= Info {
|
if logger.logLevel >= Info {
|
||||||
logger.Print("[info] " + fmt.Sprintf(msg, data...))
|
logger.Print("[info] " + fmt.Sprintf(msg, data...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn print warn messages
|
// Warn print warn messages
|
||||||
func (logger Logger) Warn(msg string, data ...interface{}) {
|
func (logger Logger) Warn(msg string, data ...interface{}) {
|
||||||
if logger.logLevel <= Warn {
|
if logger.logLevel >= Warn {
|
||||||
logger.Print("[warn] " + fmt.Sprintf(msg, data...))
|
logger.Print("[warn] " + fmt.Sprintf(msg, data...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error print error messages
|
// Error print error messages
|
||||||
func (logger Logger) Error(msg string, data ...interface{}) {
|
func (logger Logger) Error(msg string, data ...interface{}) {
|
||||||
if logger.logLevel <= Error {
|
if logger.logLevel >= Error {
|
||||||
logger.Print("[error] " + fmt.Sprintf(msg, data...))
|
logger.Print("[error] " + fmt.Sprintf(msg, data...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (logger Logger) RunWith(logLevel LogLevel, fc func()) {
|
||||||
|
if logger.logLevel >= logLevel {
|
||||||
|
fc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isPrintable(s []byte) bool {
|
||||||
|
for _, r := range s {
|
||||||
|
if !unicode.IsPrint(rune(r)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
|
||||||
|
for idx, v := range vars {
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := v.(type) {
|
||||||
|
case bool:
|
||||||
|
vars[idx] = fmt.Sprint(v)
|
||||||
|
case time.Time:
|
||||||
|
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||||
|
case *time.Time:
|
||||||
|
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||||
|
case []byte:
|
||||||
|
if isPrintable(v) {
|
||||||
|
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
} else {
|
||||||
|
vars[idx] = escaper + "<binary>" + escaper
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
|
vars[idx] = fmt.Sprintf("%d", v)
|
||||||
|
case float64, float32:
|
||||||
|
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||||
|
case string:
|
||||||
|
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
||||||
|
default:
|
||||||
|
if v == nil {
|
||||||
|
vars[idx] = "NULL"
|
||||||
|
} else {
|
||||||
|
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if numericPlaceholder == nil {
|
||||||
|
for _, v := range vars {
|
||||||
|
sql = strings.Replace(sql, "?", v.(string), 1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1")
|
||||||
|
for idx, v := range vars {
|
||||||
|
sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sql
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package logger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm/logger"
|
||||||
|
"github.com/jinzhu/now"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExplainSQL(t *testing.T) {
|
||||||
|
tt := now.MustParse("2020-02-23 11:10:10")
|
||||||
|
|
||||||
|
results := []struct {
|
||||||
|
SQL string
|
||||||
|
NumericRegexp *regexp.Regexp
|
||||||
|
Vars []interface{}
|
||||||
|
Result string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)",
|
||||||
|
NumericRegexp: regexp.MustCompile("@p(\\d+)"),
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)",
|
||||||
|
NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
|
||||||
|
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, r := range results {
|
||||||
|
if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result {
|
||||||
|
t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
dialects=("postgres" "mysql" "mssql" "sqlite")
|
dialects=("sqlite" "mysql" "postgres" "mssql")
|
||||||
|
|
||||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||||
cd ..
|
cd ..
|
||||||
|
|
Loading…
Reference in New Issue