forked from mirror/gorm
Fix sub query argument order with multiple raw SQL
This commit is contained in:
parent
df24821896
commit
84ea3ec0cc
26
statement.go
26
statement.go
|
@ -182,8 +182,32 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
}
|
}
|
||||||
case *DB:
|
case *DB:
|
||||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||||
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
if v.Statement.SQL.Len() > 0 {
|
||||||
|
var (
|
||||||
|
vars = subdb.Statement.Vars
|
||||||
|
sql = v.Statement.SQL.String()
|
||||||
|
)
|
||||||
|
|
||||||
|
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||||
|
for _, vv := range vars {
|
||||||
|
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||||
|
bindvar := strings.Builder{}
|
||||||
|
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||||
|
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
subdb.Statement.SQL.Reset()
|
||||||
|
subdb.Statement.Vars = stmt.Vars
|
||||||
|
if strings.Contains(sql, "@") {
|
||||||
|
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
} else {
|
||||||
|
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||||
subdb.callbacks.Query().Execute(subdb)
|
subdb.callbacks.Query().Execute(subdb)
|
||||||
|
}
|
||||||
|
|
||||||
writer.WriteString(subdb.Statement.SQL.String())
|
writer.WriteString(subdb.Statement.SQL.String())
|
||||||
stmt.Vars = subdb.Statement.Vars
|
stmt.Vars = subdb.Statement.Vars
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -991,13 +991,13 @@ func TestSubQueryWithRaw(t *testing.T) {
|
||||||
DB.Create(&users)
|
DB.Create(&users)
|
||||||
|
|
||||||
var count int64
|
var count int64
|
||||||
err := DB.Raw("select count(*) from (?) tmp", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_3"})).Scan(&count).Error
|
err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected to get no errors, but got %v", err)
|
t.Errorf("Expected to get no errors, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count != 2 {
|
if count != 2 {
|
||||||
t.Errorf("Row count must be 1, instead got %d", count)
|
t.Errorf("Row count must be 2, instead got %d", count)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DB.Raw("select count(*) from (?) tmp",
|
err = DB.Raw("select count(*) from (?) tmp",
|
||||||
|
|
Loading…
Reference in New Issue