Fix sub query argument order with multiple raw SQL

This commit is contained in:
Jinzhu 2021-02-09 18:56:13 +08:00
parent df24821896
commit 84ea3ec0cc
2 changed files with 28 additions and 4 deletions

View File

@ -182,8 +182,32 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
if v.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default:

View File

@ -991,13 +991,13 @@ func TestSubQueryWithRaw(t *testing.T) {
DB.Create(&users)
var count int64
err := DB.Raw("select count(*) from (?) tmp", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_3"})).Scan(&count).Error
err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
if count != 2 {
t.Errorf("Row count must be 1, instead got %d", count)
t.Errorf("Row count must be 2, instead got %d", count)
}
err = DB.Raw("select count(*) from (?) tmp",