forked from mirror/gorm
Add NamedArg support
This commit is contained in:
parent
bc3728a18f
commit
bba569af2b
|
@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||||
* Context, Prepared Statment Mode, DryRun Mode
|
* Context, Prepared Statment Mode, DryRun Mode
|
||||||
* Batch Insert, FindInBatches, Find To Map
|
* Batch Insert, FindInBatches, Find To Map
|
||||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints
|
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
||||||
* Composite Primary Key
|
* Composite Primary Key
|
||||||
* Auto Migrations
|
* Auto Migrations
|
||||||
* Logger
|
* Logger
|
||||||
|
|
|
@ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) {
|
||||||
if !stmt.DB.DryRun {
|
if !stmt.DB.DryRun {
|
||||||
stmt.SQL.Reset()
|
stmt.SQL.Reset()
|
||||||
stmt.Vars = nil
|
stmt.Vars = nil
|
||||||
stmt.NamedVars = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) {
|
||||||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.SQL = strings.Builder{}
|
tx.Statement.SQL = strings.Builder{}
|
||||||
|
|
||||||
|
if strings.Contains(sql, "@") {
|
||||||
|
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||||
|
} else {
|
||||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package clause
|
package clause
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
@ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NamedExpr raw expression for named expr
|
||||||
|
type NamedExpr struct {
|
||||||
|
SQL string
|
||||||
|
Vars []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build build raw expression
|
||||||
|
func (expr NamedExpr) Build(builder Builder) {
|
||||||
|
var (
|
||||||
|
idx int
|
||||||
|
inName bool
|
||||||
|
namedMap = make(map[string]interface{}, len(expr.Vars))
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, v := range expr.Vars {
|
||||||
|
switch value := v.(type) {
|
||||||
|
case sql.NamedArg:
|
||||||
|
namedMap[value.Name] = value.Value
|
||||||
|
case map[string]interface{}:
|
||||||
|
for k, v := range value {
|
||||||
|
namedMap[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
name := make([]byte, 0, 10)
|
||||||
|
|
||||||
|
for _, v := range []byte(expr.SQL) {
|
||||||
|
if v == '@' && !inName {
|
||||||
|
inName = true
|
||||||
|
name = []byte{}
|
||||||
|
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' {
|
||||||
|
if inName {
|
||||||
|
if nv, ok := namedMap[string(name)]; ok {
|
||||||
|
builder.AddVar(builder, nv)
|
||||||
|
} else {
|
||||||
|
builder.WriteByte('@')
|
||||||
|
builder.WriteString(string(name))
|
||||||
|
}
|
||||||
|
inName = false
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteByte(v)
|
||||||
|
} else if v == '?' {
|
||||||
|
builder.AddVar(builder, expr.Vars[idx])
|
||||||
|
idx++
|
||||||
|
} else if inName {
|
||||||
|
name = append(name, v)
|
||||||
|
} else {
|
||||||
|
builder.WriteByte(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inName {
|
||||||
|
builder.AddVar(builder, namedMap[string(name)])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IN Whether a value is within a set of values
|
// IN Whether a value is within a set of values
|
||||||
type IN struct {
|
type IN struct {
|
||||||
Column interface{}
|
Column interface{}
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package clause_test
|
package clause_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -33,3 +35,51 @@ func TestExpr(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamedExpr(t *testing.T) {
|
||||||
|
results := []struct {
|
||||||
|
SQL string
|
||||||
|
Result string
|
||||||
|
Vars []interface{}
|
||||||
|
ExpectedVars []interface{}
|
||||||
|
}{{
|
||||||
|
SQL: "create table ? (? ?, ? ?)",
|
||||||
|
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
|
||||||
|
Result: "create table `users` (`id` int, `name` text)",
|
||||||
|
}, {
|
||||||
|
SQL: "name1 = @name AND name2 = @name",
|
||||||
|
Vars: []interface{}{sql.Named("name", "jinzhu")},
|
||||||
|
Result: "name1 = ? AND name2 = ?",
|
||||||
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||||
|
}, {
|
||||||
|
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||||
|
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||||
|
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||||
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||||
|
}, {
|
||||||
|
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||||
|
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
|
||||||
|
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||||
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||||
|
}, {
|
||||||
|
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
|
||||||
|
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||||
|
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||||
|
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for idx, result := range results {
|
||||||
|
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||||
|
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||||
|
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||||
|
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||||
|
if stmt.SQL.String() != result.Result {
|
||||||
|
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB {
|
||||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.SQL = strings.Builder{}
|
tx.Statement.SQL = strings.Builder{}
|
||||||
|
|
||||||
|
if strings.Contains(sql, "@") {
|
||||||
|
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||||
|
} else {
|
||||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||||
|
}
|
||||||
|
|
||||||
tx.callbacks.Raw().Execute(tx)
|
tx.callbacks.Raw().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
23
statement.go
23
statement.go
|
@ -38,7 +38,6 @@ type Statement struct {
|
||||||
UpdatingColumn bool
|
UpdatingColumn bool
|
||||||
SQL strings.Builder
|
SQL strings.Builder
|
||||||
Vars []interface{}
|
Vars []interface{}
|
||||||
NamedVars []sql.NamedArg
|
|
||||||
CurDestIndex int
|
CurDestIndex int
|
||||||
attrs []interface{}
|
attrs []interface{}
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
|
@ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
|
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
case sql.NamedArg:
|
case sql.NamedArg:
|
||||||
if len(v.Name) > 0 {
|
|
||||||
stmt.NamedVars = append(stmt.NamedVars, v)
|
|
||||||
writer.WriteByte('@')
|
|
||||||
writer.WriteString(v.Name)
|
|
||||||
} else {
|
|
||||||
stmt.Vars = append(stmt.Vars, v.Value)
|
stmt.Vars = append(stmt.Vars, v.Value)
|
||||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value)
|
|
||||||
}
|
|
||||||
case clause.Column, clause.Table:
|
case clause.Column, clause.Table:
|
||||||
stmt.QuoteTo(writer, v)
|
stmt.QuoteTo(writer, v)
|
||||||
case clause.Expr:
|
case clause.Expr:
|
||||||
|
@ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||||
|
|
||||||
// BuildCondition build condition
|
// BuildCondition build condition
|
||||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||||
if sql, ok := query.(string); ok {
|
if s, ok := query.(string); ok {
|
||||||
// if it is a number, then treats it as primary key
|
// if it is a number, then treats it as primary key
|
||||||
if _, err := strconv.Atoi(sql); err != nil {
|
if _, err := strconv.Atoi(s); err != nil {
|
||||||
if sql == "" && len(args) == 0 {
|
if s == "" && len(args) == 0 {
|
||||||
return
|
return
|
||||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
||||||
// looks like a where condition
|
// looks like a where condition
|
||||||
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||||
|
} else if len(args) > 0 && strings.Contains(s, "@") {
|
||||||
|
// looks like a named query
|
||||||
|
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
||||||
} else if len(args) == 1 {
|
} else if len(args) == 1 {
|
||||||
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
|
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNamedArg(t *testing.T) {
|
||||||
|
type NamedUser struct {
|
||||||
|
gorm.Model
|
||||||
|
Name1 string
|
||||||
|
Name2 string
|
||||||
|
Name3 string
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&NamedUser{})
|
||||||
|
DB.AutoMigrate(&NamedUser{})
|
||||||
|
|
||||||
|
namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"}
|
||||||
|
DB.Create(&namedUser)
|
||||||
|
|
||||||
|
var result NamedUser
|
||||||
|
DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2"))
|
||||||
|
|
||||||
|
AssertEqual(t, result, namedUser)
|
||||||
|
|
||||||
|
var result2 NamedUser
|
||||||
|
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2)
|
||||||
|
|
||||||
|
AssertEqual(t, result2, namedUser)
|
||||||
|
|
||||||
|
var result3 NamedUser
|
||||||
|
DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3)
|
||||||
|
|
||||||
|
AssertEqual(t, result3, namedUser)
|
||||||
|
|
||||||
|
var result4 NamedUser
|
||||||
|
if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with named arg")
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result4, namedUser)
|
||||||
|
|
||||||
|
if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with named arg")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result5 NamedUser
|
||||||
|
if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil {
|
||||||
|
t.Errorf("failed to update with named arg")
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result4, namedUser)
|
||||||
|
}
|
Loading…
Reference in New Issue