Refactor clause Writer

This commit is contained in:
Jinzhu 2020-03-09 17:07:00 +08:00
parent 3aa1891068
commit 504f42760a
20 changed files with 117 additions and 108 deletions

View File

@ -12,13 +12,16 @@ type ClauseBuilder interface {
Build(Clause, Builder)
}
type Writer interface {
WriteByte(byte) error
WriteString(string) (int, error)
}
// Builder builder interface
type Builder interface {
WriteByte(byte) error
Write(sql ...string) error
Writer
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
AddVar(Writer, ...interface{})
}
// Clause

View File

@ -9,11 +9,11 @@ func (d Delete) Name() string {
}
func (d Delete) Build(builder Builder) {
builder.Write("DELETE")
builder.WriteString("DELETE")
if d.Modifier != "" {
builder.WriteByte(' ')
builder.Write(d.Modifier)
builder.WriteString(d.Modifier)
}
}

View File

@ -1,9 +1,5 @@
package clause
import (
"strings"
)
// Expression expression interface
type Expression interface {
Build(builder Builder)
@ -22,11 +18,15 @@ type Expr struct {
// Build build raw expression
func (expr Expr) Build(builder Builder) {
sql := expr.SQL
for _, v := range expr.Vars {
sql = strings.Replace(sql, "?", builder.AddVar(v), 1)
var idx int
for _, v := range []byte(expr.SQL) {
if v == '?' {
builder.AddVar(builder, expr.Vars[idx])
idx++
} else {
builder.WriteByte(v)
}
}
builder.Write(sql)
}
// IN Whether a value is within a set of values
@ -40,11 +40,14 @@ func (in IN) Build(builder Builder) {
switch len(in.Values) {
case 0:
builder.Write(" IN (NULL)")
builder.WriteString(" IN (NULL)")
case 1:
builder.Write(" = ", builder.AddVar(in.Values...))
builder.WriteString(" = ")
builder.AddVar(builder, in.Values...)
default:
builder.Write(" IN (", builder.AddVar(in.Values...), ")")
builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
@ -52,9 +55,12 @@ func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
builder.Write(" <> ", builder.AddVar(in.Values...))
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values...)
default:
builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")")
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
@ -68,9 +74,10 @@ func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eq.Value == nil {
builder.Write(" IS NULL")
builder.WriteString(" IS NULL")
} else {
builder.Write(" = ", builder.AddVar(eq.Value))
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
}
@ -85,9 +92,10 @@ func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if neq.Value == nil {
builder.Write(" IS NOT NULL")
builder.WriteString(" IS NOT NULL")
} else {
builder.Write(" <> ", builder.AddVar(neq.Value))
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
}
@ -100,7 +108,8 @@ type Gt Eq
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.Write(" > ", builder.AddVar(gt.Value))
builder.WriteString(" > ")
builder.AddVar(builder, gt.Value)
}
func (gt Gt) NegationBuild(builder Builder) {
@ -112,7 +121,8 @@ type Gte Eq
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.Write(" >= ", builder.AddVar(gte.Value))
builder.WriteString(" >= ")
builder.AddVar(builder, gte.Value)
}
func (gte Gte) NegationBuild(builder Builder) {
@ -124,7 +134,8 @@ type Lt Eq
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.Write(" < ", builder.AddVar(lt.Value))
builder.WriteString(" < ")
builder.AddVar(builder, lt.Value)
}
func (lt Lt) NegationBuild(builder Builder) {
@ -136,7 +147,8 @@ type Lte Eq
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.Write(" <= ", builder.AddVar(lte.Value))
builder.WriteString(" <= ")
builder.AddVar(builder, lte.Value)
}
func (lte Lte) NegationBuild(builder Builder) {
@ -148,12 +160,14 @@ type Like Eq
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" LIKE ", builder.AddVar(like.Value))
builder.WriteString(" LIKE ")
builder.AddVar(builder, like.Value)
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
}
// Map

View File

@ -50,18 +50,18 @@ func (from From) Build(builder Builder) {
func (join Join) Build(builder Builder) {
if join.Type != "" {
builder.Write(string(join.Type))
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.Write("JOIN ")
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.Write(" ON ")
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.Write(" USING (")
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')

View File

@ -22,7 +22,7 @@ func (groupBy GroupBy) Build(builder Builder) {
}
if len(groupBy.Having) > 0 {
builder.Write(" HAVING ")
builder.WriteString(" HAVING ")
Where{Exprs: groupBy.Having}.Build(builder)
}
}

View File

@ -13,11 +13,11 @@ func (insert Insert) Name() string {
// Build build insert clause
func (insert Insert) Build(builder Builder) {
if insert.Modifier != "" {
builder.Write(insert.Modifier)
builder.WriteString(insert.Modifier)
builder.WriteByte(' ')
}
builder.Write("INTO ")
builder.WriteString("INTO ")
if insert.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {

View File

@ -16,12 +16,12 @@ func (limit Limit) Name() string {
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 {
builder.Write("LIMIT ")
builder.Write(strconv.Itoa(limit.Limit))
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit))
if limit.Offset > 0 {
builder.Write(" OFFSET ")
builder.Write(strconv.Itoa(limit.Offset))
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
}

View File

@ -22,16 +22,16 @@ func (f For) Build(builder Builder) {
builder.WriteByte(' ')
}
builder.Write("FOR ")
builder.Write(locking.Strength)
builder.WriteString("FOR ")
builder.WriteString(locking.Strength)
if locking.Table.Name != "" {
builder.Write(" OF ")
builder.WriteString(" OF ")
builder.WriteQuoted(locking.Table)
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.Write(locking.Options)
builder.WriteString(locking.Options)
}
}
}

View File

@ -24,7 +24,7 @@ func (orderBy OrderBy) Build(builder Builder) {
builder.WriteQuoted(column.Column)
if column.Desc {
builder.Write(" DESC")
builder.WriteString(" DESC")
}
}
}

View File

@ -19,7 +19,7 @@ func (set Set) Build(builder Builder) {
}
builder.WriteQuoted(assignment.Column)
builder.WriteByte('=')
builder.Write(builder.AddVar(assignment.Value))
builder.AddVar(builder, assignment.Value)
}
} else {
builder.WriteQuoted(PrimaryColumn)

View File

@ -13,7 +13,7 @@ func (update Update) Name() string {
// Build build update clause
func (update Update) Build(builder Builder) {
if update.Modifier != "" {
builder.Write(update.Modifier)
builder.WriteString(update.Modifier)
builder.WriteByte(' ')
}

View File

@ -22,7 +22,7 @@ func (values Values) Build(builder Builder) {
}
builder.WriteByte(')')
builder.Write(" VALUES ")
builder.WriteString(" VALUES ")
for idx, value := range values.Values {
if idx > 0 {
@ -30,11 +30,11 @@ func (values Values) Build(builder Builder) {
}
builder.WriteByte('(')
builder.Write(builder.AddVar(value...))
builder.AddVar(builder, value...)
builder.WriteByte(')')
}
} else {
builder.Write("DEFAULT VALUES")
builder.WriteString("DEFAULT VALUES")
}
}

View File

@ -26,9 +26,9 @@ func (where Where) Build(builder Builder) {
if expr != nil {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.Write(" OR ")
builder.WriteString(" OR ")
} else {
builder.Write(" AND ")
builder.WriteString(" AND ")
}
}
@ -65,7 +65,7 @@ func (and AndConditions) Build(builder Builder) {
}
for idx, c := range and.Exprs {
if idx > 0 {
builder.Write(" AND ")
builder.WriteString(" AND ")
}
c.Build(builder)
}
@ -91,7 +91,7 @@ func (or OrConditions) Build(builder Builder) {
}
for idx, c := range or.Exprs {
if idx > 0 {
builder.Write(" OR ")
builder.WriteString(" OR ")
}
c.Build(builder)
}
@ -117,13 +117,13 @@ func (not NotConditions) Build(builder Builder) {
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.Write(" AND ")
builder.WriteString(" AND ")
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.Write(" NOT ")
builder.WriteString(" NOT ")
c.Build(builder)
}
}

View File

@ -5,11 +5,11 @@ import (
"fmt"
"regexp"
"strconv"
"strings"
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema"
@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "@p" + strconv.Itoa(len(stmt.Vars))
}
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('"')
writer.WriteString(str)
writer.WriteByte('"')
}
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")

View File

@ -4,11 +4,11 @@ import (
"database/sql"
"fmt"
"math"
"strings"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema"
@ -40,10 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
writer.WriteString(str)
writer.WriteByte('`')
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {

View File

@ -5,10 +5,10 @@ import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema"
@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "$" + strconv.Itoa(len(stmt.Vars))
}
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('"')
builder.WriteString(str)
builder.WriteByte('"')
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('"')
writer.WriteString(str)
writer.WriteByte('"')
}
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")

View File

@ -2,10 +2,10 @@ package sqlite
import (
"database/sql"
"strings"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/migrator"
"github.com/jinzhu/gorm/schema"
@ -39,10 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
writer.WriteString(str)
writer.WriteByte('`')
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {

View File

@ -3,8 +3,8 @@ package gorm
import (
"context"
"database/sql"
"strings"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
)
@ -14,7 +14,7 @@ type Dialector interface {
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
BindVar(stmt *Statement, v interface{}) string
QuoteTo(*strings.Builder, string)
QuoteTo(clause.Writer, string)
Explain(sql string, vars ...interface{}) string
}

View File

@ -34,7 +34,6 @@ type Statement struct {
SQL strings.Builder
Vars []interface{}
NamedVars []sql.NamedArg
placeholders strings.Builder
}
// StatementOptimizer statement optimizer interface
@ -43,15 +42,12 @@ type StatementOptimizer interface {
}
// Write write string
func (stmt *Statement) Write(sql ...string) (err error) {
for _, s := range sql {
_, err = stmt.SQL.WriteString(s)
}
return
func (stmt *Statement) WriteString(str string) (int, error) {
return stmt.SQL.WriteString(str)
}
// Write write string
func (stmt *Statement) WriteByte(c byte) (err error) {
func (stmt *Statement) WriteByte(c byte) error {
return stmt.SQL.WriteByte(c)
}
@ -62,7 +58,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error {
}
// QuoteTo write quoted value to writer
func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) {
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
@ -110,44 +106,41 @@ func (stmt Statement) Quote(field interface{}) string {
}
// Write write string
func (stmt *Statement) AddVar(vars ...interface{}) string {
stmt.placeholders = strings.Builder{}
stmt.placeholders.Reset()
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
for idx, v := range vars {
if idx > 0 {
stmt.placeholders.WriteByte(',')
writer.WriteByte(',')
}
switch v := v.(type) {
case sql.NamedArg:
if len(v.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, v)
stmt.placeholders.WriteByte('@')
stmt.placeholders.WriteString(v.Name)
writer.WriteByte('@')
writer.WriteString(v.Name)
} else {
stmt.Vars = append(stmt.Vars, v.Value)
stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
}
case clause.Column, clause.Table:
stmt.placeholders.WriteString(stmt.Quote(v))
stmt.QuoteTo(writer, v)
case clause.Expr:
stmt.placeholders.WriteString(v.SQL)
writer.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
case []interface{}:
if len(v) > 0 {
stmt.placeholders.WriteByte('(')
stmt.placeholders.WriteString(stmt.AddVar(v...))
stmt.placeholders.WriteByte(')')
writer.WriteByte('(')
stmt.skipResetPlacehodler = true
stmt.AddVar(writer, v...)
writer.WriteByte(')')
} else {
stmt.placeholders.WriteString("(NULL)")
writer.WriteString("(NULL)")
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
}
}
return stmt.placeholders.String()
}
// AddClause add clause

View File

@ -1,9 +1,8 @@
package tests
import (
"strings"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema"
)
@ -23,10 +22,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (DummyDialector) QuoteTo(builder *strings.Builder, str string) {
builder.WriteByte('`')
builder.WriteString(str)
builder.WriteByte('`')
func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
writer.WriteString(str)
writer.WriteByte('`')
}
func (DummyDialector) Explain(sql string, vars ...interface{}) string {