diff --git a/statement.go b/statement.go index 6ea8c883..aac4f073 100644 --- a/statement.go +++ b/statement.go @@ -182,8 +182,32 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } case *DB: subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) - subdb.callbacks.Query().Execute(subdb) + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + writer.WriteString(subdb.Statement.SQL.String()) stmt.Vars = subdb.Statement.Vars default: diff --git a/tests/query_test.go b/tests/query_test.go index 8ed02c98..be6768b1 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -991,13 +991,13 @@ func TestSubQueryWithRaw(t *testing.T) { DB.Create(&users) var count int64 - err := DB.Raw("select count(*) from (?) tmp", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_3"})).Scan(&count).Error + err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } if count != 2 { - t.Errorf("Row count must be 1, instead got %d", count) + t.Errorf("Row count must be 2, instead got %d", count) } err = DB.Raw("select count(*) from (?) tmp",