diff --git a/README.md b/README.md index 140c0d28..b51297c4 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks.go b/callbacks.go index 5e7933af..c917a678 100644 --- a/callbacks.go +++ b/callbacks.go @@ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) { if !stmt.DB.DryRun { stmt.SQL.Reset() stmt.Vars = nil - stmt.NamedVars = nil } } diff --git a/chainable_api.go b/chainable_api.go index acceb58f..3e509f12 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } return } diff --git a/clause/expression.go b/clause/expression.go index ecf8ba85..4d5e328b 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,7 @@ package clause import ( + "database/sql" "database/sql/driver" "reflect" ) @@ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) { } } +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + builder.WriteByte(v) + } else if v == '?' { + builder.AddVar(builder, expr.Vars[idx]) + idx++ + } else if inName { + name = append(name, v) + } else { + builder.WriteByte(v) + } + } + + if inName { + builder.AddVar(builder, namedMap[string(name)]) + } +} + // IN Whether a value is within a set of values type IN struct { Column interface{} diff --git a/clause/expression_test.go b/clause/expression_test.go index 3059aea6..17af737d 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -1,7 +1,9 @@ package clause_test import ( + "database/sql" "fmt" + "reflect" "sync" "testing" @@ -33,3 +35,51 @@ func TestExpr(t *testing.T) { }) } } + +func TestNamedExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + ExpectedVars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }, { + SQL: "name1 = @name AND name2 = @name", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } +} diff --git a/finisher_api.go b/finisher_api.go index 25c56e49..d70b3cd0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + tx.callbacks.Raw().Execute(tx) return } diff --git a/statement.go b/statement.go index 036b8297..00feeac5 100644 --- a/statement.go +++ b/statement.go @@ -38,7 +38,6 @@ type Statement struct { UpdatingColumn bool SQL strings.Builder Vars []interface{} - NamedVars []sql.NamedArg CurDestIndex int attrs []interface{} assigns []interface{} @@ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { switch v := v.(type) { case sql.NamedArg: - if len(v.Name) > 0 { - stmt.NamedVars = append(stmt.NamedVars, v) - writer.WriteByte('@') - writer.WriteString(v.Name) - } else { - stmt.Vars = append(stmt.Vars, v.Value) - stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value) - } + stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case clause.Expr: @@ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondition build condition func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { - if sql, ok := query.(string); ok { + if s, ok := query.(string); ok { // if it is a number, then treats it as primary key - if _, err := strconv.Atoi(sql); err != nil { - if sql == "" && len(args) == 0 { + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { return - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition - return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } else if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} } else if len(args) == 1 { - return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } } } diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go new file mode 100644 index 00000000..60f5a535 --- /dev/null +++ b/tests/named_argument_test.go @@ -0,0 +1,57 @@ +package tests_test + +import ( + "database/sql" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestNamedArg(t *testing.T) { + type NamedUser struct { + gorm.Model + Name1 string + Name2 string + Name3 string + } + + DB.Migrator().DropTable(&NamedUser{}) + DB.AutoMigrate(&NamedUser{}) + + namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} + DB.Create(&namedUser) + + var result NamedUser + DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) + + AssertEqual(t, result, namedUser) + + var result2 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) + + AssertEqual(t, result2, namedUser) + + var result3 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) + + AssertEqual(t, result3, namedUser) + + var result4 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) + + if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + var result5 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) +}