forked from mirror/gorm
Add NamedArg support
This commit is contained in:
parent
bc3728a18f
commit
bba569af2b
|
@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statment Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
||||
* Composite Primary Key
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
|
|
|
@ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) {
|
|||
if !stmt.DB.DryRun {
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
stmt.NamedVars = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) {
|
|||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package clause
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
@ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) {
|
|||
}
|
||||
}
|
||||
|
||||
// NamedExpr raw expression for named expr
|
||||
type NamedExpr struct {
|
||||
SQL string
|
||||
Vars []interface{}
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr NamedExpr) Build(builder Builder) {
|
||||
var (
|
||||
idx int
|
||||
inName bool
|
||||
namedMap = make(map[string]interface{}, len(expr.Vars))
|
||||
)
|
||||
|
||||
for _, v := range expr.Vars {
|
||||
switch value := v.(type) {
|
||||
case sql.NamedArg:
|
||||
namedMap[value.Name] = value.Value
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
namedMap[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := make([]byte, 0, 10)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = []byte{}
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
} else {
|
||||
builder.WriteByte('@')
|
||||
builder.WriteString(string(name))
|
||||
}
|
||||
inName = false
|
||||
}
|
||||
|
||||
builder.WriteByte(v)
|
||||
} else if v == '?' {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
idx++
|
||||
} else if inName {
|
||||
name = append(name, v)
|
||||
} else {
|
||||
builder.WriteByte(v)
|
||||
}
|
||||
}
|
||||
|
||||
if inName {
|
||||
builder.AddVar(builder, namedMap[string(name)])
|
||||
}
|
||||
}
|
||||
|
||||
// IN Whether a value is within a set of values
|
||||
type IN struct {
|
||||
Column interface{}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package clause_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
|
@ -33,3 +35,51 @@ func TestExpr(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedExpr(t *testing.T) {
|
||||
results := []struct {
|
||||
SQL string
|
||||
Result string
|
||||
Vars []interface{}
|
||||
ExpectedVars []interface{}
|
||||
}{{
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
|
||||
Result: "create table `users` (`id` int, `name` text)",
|
||||
}, {
|
||||
SQL: "name1 = @name AND name2 = @name",
|
||||
Vars: []interface{}{sql.Named("name", "jinzhu")},
|
||||
Result: "name1 = ? AND name2 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
|
||||
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||
}, {
|
||||
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
|
||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB {
|
|||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
|
23
statement.go
23
statement.go
|
@ -38,7 +38,6 @@ type Statement struct {
|
|||
UpdatingColumn bool
|
||||
SQL strings.Builder
|
||||
Vars []interface{}
|
||||
NamedVars []sql.NamedArg
|
||||
CurDestIndex int
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
|
@ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
|
||||
switch v := v.(type) {
|
||||
case sql.NamedArg:
|
||||
if len(v.Name) > 0 {
|
||||
stmt.NamedVars = append(stmt.NamedVars, v)
|
||||
writer.WriteByte('@')
|
||||
writer.WriteString(v.Name)
|
||||
} else {
|
||||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value)
|
||||
}
|
||||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case clause.Expr:
|
||||
|
@ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
|||
|
||||
// BuildCondition build condition
|
||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
if s, ok := query.(string); ok {
|
||||
// if it is a number, then treats it as primary key
|
||||
if _, err := strconv.Atoi(sql); err != nil {
|
||||
if sql == "" && len(args) == 0 {
|
||||
if _, err := strconv.Atoi(s); err != nil {
|
||||
if s == "" && len(args) == 0 {
|
||||
return
|
||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
||||
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||
} else if len(args) > 0 && strings.Contains(s, "@") {
|
||||
// looks like a named query
|
||||
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
||||
} else if len(args) == 1 {
|
||||
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
|
||||
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestNamedArg(t *testing.T) {
|
||||
type NamedUser struct {
|
||||
gorm.Model
|
||||
Name1 string
|
||||
Name2 string
|
||||
Name3 string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&NamedUser{})
|
||||
DB.AutoMigrate(&NamedUser{})
|
||||
|
||||
namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"}
|
||||
DB.Create(&namedUser)
|
||||
|
||||
var result NamedUser
|
||||
DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2"))
|
||||
|
||||
AssertEqual(t, result, namedUser)
|
||||
|
||||
var result2 NamedUser
|
||||
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2)
|
||||
|
||||
AssertEqual(t, result2, namedUser)
|
||||
|
||||
var result3 NamedUser
|
||||
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3)
|
||||
|
||||
AssertEqual(t, result3, namedUser)
|
||||
|
||||
var result4 NamedUser
|
||||
if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
AssertEqual(t, result4, namedUser)
|
||||
|
||||
if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
var result5 NamedUser
|
||||
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil {
|
||||
t.Errorf("failed to update with named arg")
|
||||
}
|
||||
|
||||
AssertEqual(t, result4, namedUser)
|
||||
}
|
Loading…
Reference in New Issue