Add NamedArg support

This commit is contained in:
Jinzhu 2020-07-10 12:28:24 +08:00
parent bc3728a18f
commit bba569af2b
8 changed files with 190 additions and 19 deletions

View File

@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statment Mode, DryRun Mode * Context, Prepared Statment Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map * Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
* Composite Primary Key * Composite Primary Key
* Auto Migrations * Auto Migrations
* Logger * Logger

View File

@ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) {
if !stmt.DB.DryRun { if !stmt.DB.DryRun {
stmt.SQL.Reset() stmt.SQL.Reset()
stmt.Vars = nil stmt.Vars = nil
stmt.NamedVars = nil
} }
} }

View File

@ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return return
} }

View File

@ -1,6 +1,7 @@
package clause package clause
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
) )
@ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) {
} }
} }
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build raw expression
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
namedMap = make(map[string]interface{}, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
builder.WriteByte(v)
} else if v == '?' {
builder.AddVar(builder, expr.Vars[idx])
idx++
} else if inName {
name = append(name, v)
} else {
builder.WriteByte(v)
}
}
if inName {
builder.AddVar(builder, namedMap[string(name)])
}
}
// IN Whether a value is within a set of values // IN Whether a value is within a set of values
type IN struct { type IN struct {
Column interface{} Column interface{}

View File

@ -1,7 +1,9 @@
package clause_test package clause_test
import ( import (
"database/sql"
"fmt" "fmt"
"reflect"
"sync" "sync"
"testing" "testing"
@ -33,3 +35,51 @@ func TestExpr(t *testing.T) {
}) })
} }
} }
func TestNamedExpr(t *testing.T) {
results := []struct {
SQL string
Result string
Vars []interface{}
ExpectedVars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}, {
SQL: "name1 = @name AND name2 = @name",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}

View File

@ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB {
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
tx.callbacks.Raw().Execute(tx) tx.callbacks.Raw().Execute(tx)
return return
} }

View File

@ -38,7 +38,6 @@ type Statement struct {
UpdatingColumn bool UpdatingColumn bool
SQL strings.Builder SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg
CurDestIndex int CurDestIndex int
attrs []interface{} attrs []interface{}
assigns []interface{} assigns []interface{}
@ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
switch v := v.(type) { switch v := v.(type) {
case sql.NamedArg: case sql.NamedArg:
if len(v.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, v)
writer.WriteByte('@')
writer.WriteString(v.Name)
} else {
stmt.Vars = append(stmt.Vars, v.Value) stmt.Vars = append(stmt.Vars, v.Value)
stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value)
}
case clause.Column, clause.Table: case clause.Column, clause.Table:
stmt.QuoteTo(writer, v) stmt.QuoteTo(writer, v)
case clause.Expr: case clause.Expr:
@ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
// BuildCondition build condition // BuildCondition build condition
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
if sql, ok := query.(string); ok { if s, ok := query.(string); ok {
// if it is a number, then treats it as primary key // if it is a number, then treats it as primary key
if _, err := strconv.Atoi(sql); err != nil { if _, err := strconv.Atoi(s); err != nil {
if sql == "" && len(args) == 0 { if s == "" && len(args) == 0 {
return return
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition // looks like a where condition
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
} else if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
} else if len(args) == 1 { } else if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
} }
} }
} }

View File

@ -0,0 +1,57 @@
package tests_test
import (
"database/sql"
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestNamedArg(t *testing.T) {
type NamedUser struct {
gorm.Model
Name1 string
Name2 string
Name3 string
}
DB.Migrator().DropTable(&NamedUser{})
DB.AutoMigrate(&NamedUser{})
namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"}
DB.Create(&namedUser)
var result NamedUser
DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2"))
AssertEqual(t, result, namedUser)
var result2 NamedUser
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2)
AssertEqual(t, result2, namedUser)
var result3 NamedUser
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3)
AssertEqual(t, result3, namedUser)
var result4 NamedUser
if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil {
t.Errorf("failed to update with named arg")
}
AssertEqual(t, result4, namedUser)
if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil {
t.Errorf("failed to update with named arg")
}
var result5 NamedUser
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil {
t.Errorf("failed to update with named arg")
}
AssertEqual(t, result4, namedUser)
}