From 078ba75b9cc749820610e11b205a2e219a5e7239 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 23:30:16 +0800 Subject: [PATCH] Add QuoteTo method --- dialects/mssql/mssql.go | 7 +++-- dialects/mysql/mysql.go | 7 +++-- dialects/postgres/postgres.go | 7 +++-- dialects/sqlite/sqlite.go | 7 +++-- go.mod | 4 +++ gorm.go | 1 - interfaces.go | 3 +- statement.go | 55 +++++++++++++++-------------------- tests/dummy_dialecter.go | 8 +++-- 9 files changed, 56 insertions(+), 43 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index b93cc8f6..91574787 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "strconv" + "strings" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "@p" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'"', '"'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('"') + builder.WriteString(str) + builder.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("@p(\\d+)") diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index e1bf985a..9d16507e 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "math" + "strings" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" @@ -39,8 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 3ee4ba9f..0005f7ed 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "strconv" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "$" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'"', '"'} // "name" +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('"') + builder.WriteString(str) + builder.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 5f9d49df..91762343 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -2,6 +2,7 @@ package sqlite import ( "database/sql" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -38,8 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/go.mod b/go.mod index cdb7e574..3e067d3c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,10 @@ module github.com/jinzhu/gorm go 1.13 require ( + github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect + github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 + github.com/lib/pq v1.3.0 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/gorm.go b/gorm.go index 2f10be60..eac95868 100644 --- a/gorm.go +++ b/gorm.go @@ -79,7 +79,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if dialector != nil { err = dialector.Initialize(db) - db.quoteChars = dialector.QuoteChars() } return } diff --git a/interfaces.go b/interfaces.go index f0d14dd8..c89c3624 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "strings" "github.com/jinzhu/gorm/schema" ) @@ -13,7 +14,7 @@ type Dialector interface { Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string - QuoteChars() [2]byte + QuoteTo(*strings.Builder, string) Explain(sql string, vars ...interface{}) string } diff --git a/statement.go b/statement.go index bad83717..f04ea269 100644 --- a/statement.go +++ b/statement.go @@ -76,65 +76,58 @@ func (stmt *Statement) WriteByte(c byte) (err error) { return stmt.SQL.WriteByte(c) } -// WriteQuoted write quoted field -func (stmt *Statement) WriteQuoted(field interface{}) (err error) { - _, err = stmt.SQL.WriteString(stmt.Quote(field)) - return +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) error { + stmt.QuoteTo(&stmt.SQL, value) + return nil } -// Quote returns quoted value -func (stmt Statement) Quote(field interface{}) string { - var str strings.Builder - str.WriteByte(stmt.DB.quoteChars[0]) - +// QuoteTo write quoted value to writer +func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { - str.WriteString(stmt.Table) + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) } else { - str.WriteString(v.Name) + stmt.DB.Dialector.QuoteTo(writer, v.Name) } if v.Alias != "" { - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteString(" AS ") - str.WriteByte(stmt.DB.quoteChars[0]) - str.WriteString(v.Alias) - str.WriteByte(stmt.DB.quoteChars[1]) + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) } case clause.Column: if v.Table != "" { if v.Table == clause.CurrentTable { - str.WriteString(stmt.Table) + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) } else { - str.WriteString(v.Table) + stmt.DB.Dialector.QuoteTo(writer, v.Table) } - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteByte('.') - str.WriteByte(stmt.DB.quoteChars[0]) + writer.WriteByte('.') } if v.Name == clause.PrimaryKey { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { - str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName) + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } } else { - str.WriteString(v.Name) + stmt.DB.Dialector.QuoteTo(writer, v.Name) } if v.Alias != "" { - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteString(" AS ") - str.WriteByte(stmt.DB.quoteChars[0]) - str.WriteString(v.Alias) - str.WriteByte(stmt.DB.quoteChars[1]) + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) } default: - str.WriteString(fmt.Sprint(field)) + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } +} - str.WriteByte(stmt.DB.quoteChars[1]) - return str.String() +// Quote returns quoted value +func (stmt Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() } // Write write string diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 04d6248d..9e3146fe 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -1,6 +1,8 @@ package tests import ( + "strings" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" @@ -21,8 +23,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (DummyDialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (DummyDialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string {