Add QuoteTo method

This commit is contained in:
Jinzhu 2020-03-08 23:30:16 +08:00
parent 5fce17543a
commit 078ba75b9c
9 changed files with 56 additions and 43 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "@p" + strconv.Itoa(len(stmt.Vars)) return "@p" + strconv.Itoa(len(stmt.Vars))
} }
func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
return [2]byte{'"', '"'} // `name` builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
} }
var numericPlaceholder = regexp.MustCompile("@p(\\d+)") var numericPlaceholder = regexp.MustCompile("@p(\\d+)")

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"math" "math"
"strings"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -39,8 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
return [2]byte{'`', '`'} // `name` builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
} }
func (dialector Dialector) Explain(sql string, vars ...interface{}) string { func (dialector Dialector) Explain(sql string, vars ...interface{}) string {

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "$" + strconv.Itoa(len(stmt.Vars)) return "$" + strconv.Itoa(len(stmt.Vars))
} }
func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
return [2]byte{'"', '"'} // "name" builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
} }
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")

View File

@ -2,6 +2,7 @@ package sqlite
import ( import (
"database/sql" "database/sql"
"strings"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
@ -38,8 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
return [2]byte{'`', '`'} // `name` builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
} }
func (dialector Dialector) Explain(sql string, vars ...interface{}) string { func (dialector Dialector) Explain(sql string, vars ...interface{}) string {

4
go.mod
View File

@ -3,6 +3,10 @@ module github.com/jinzhu/gorm
go 1.13 go 1.13
require ( require (
github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
) )

View File

@ -79,7 +79,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
if dialector != nil { if dialector != nil {
err = dialector.Initialize(db) err = dialector.Initialize(db)
db.quoteChars = dialector.QuoteChars()
} }
return return
} }

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
) )
@ -13,7 +14,7 @@ type Dialector interface {
Migrator(db *DB) Migrator Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string DataTypeOf(*schema.Field) string
BindVar(stmt *Statement, v interface{}) string BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte QuoteTo(*strings.Builder, string)
Explain(sql string, vars ...interface{}) string Explain(sql string, vars ...interface{}) string
} }

View File

@ -76,65 +76,58 @@ func (stmt *Statement) WriteByte(c byte) (err error) {
return stmt.SQL.WriteByte(c) return stmt.SQL.WriteByte(c)
} }
// WriteQuoted write quoted field // WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(field interface{}) (err error) { func (stmt *Statement) WriteQuoted(value interface{}) error {
_, err = stmt.SQL.WriteString(stmt.Quote(field)) stmt.QuoteTo(&stmt.SQL, value)
return return nil
} }
// Quote returns quoted value // QuoteTo write quoted value to writer
func (stmt Statement) Quote(field interface{}) string { func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) {
var str strings.Builder
str.WriteByte(stmt.DB.quoteChars[0])
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
if v.Name == clause.CurrentTable { if v.Name == clause.CurrentTable {
str.WriteString(stmt.Table) stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
} else { } else {
str.WriteString(v.Name) stmt.DB.Dialector.QuoteTo(writer, v.Name)
} }
if v.Alias != "" { if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1]) writer.WriteString(" AS ")
str.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias)
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
} }
case clause.Column: case clause.Column:
if v.Table != "" { if v.Table != "" {
if v.Table == clause.CurrentTable { if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table) stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
} else { } else {
str.WriteString(v.Table) stmt.DB.Dialector.QuoteTo(writer, v.Table)
} }
str.WriteByte(stmt.DB.quoteChars[1]) writer.WriteByte('.')
str.WriteByte('.')
str.WriteByte(stmt.DB.quoteChars[0])
} }
if v.Name == clause.PrimaryKey { if v.Name == clause.PrimaryKey {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName) stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
} }
} else { } else {
str.WriteString(v.Name) stmt.DB.Dialector.QuoteTo(writer, v.Name)
} }
if v.Alias != "" { if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1]) writer.WriteString(" AS ")
str.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias)
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
} }
default: default:
str.WriteString(fmt.Sprint(field)) stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
}
} }
str.WriteByte(stmt.DB.quoteChars[1]) // Quote returns quoted value
return str.String() func (stmt Statement) Quote(field interface{}) string {
var builder strings.Builder
stmt.QuoteTo(&builder, field)
return builder.String()
} }
// Write write string // Write write string

View File

@ -1,6 +1,8 @@
package tests package tests
import ( import (
"strings"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
@ -21,8 +23,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (DummyDialector) QuoteChars() [2]byte { func (DummyDialector) QuoteTo(builder *strings.Builder, str string) {
return [2]byte{'`', '`'} // `name` builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
} }
func (DummyDialector) Explain(sql string, vars ...interface{}) string { func (DummyDialector) Explain(sql string, vars ...interface{}) string {