diff --git a/callbacks.go b/callbacks.go index d05947d9..d3cd8e62 100644 --- a/callbacks.go +++ b/callbacks.go @@ -80,7 +80,7 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { db.AddError(err) } } diff --git a/callbacks/create.go b/callbacks/create.go index 7a2b8bfe..01329141 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -63,36 +63,38 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + if err == nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } } @@ -135,42 +137,44 @@ func CreateWithReturning(db *gorm.DB) { idx++ } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - defer rows.Close() + if err == nil { + defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } - db.RowsAffected++ - } - case reflect.Struct: - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - if rows.Next() { - db.RowsAffected++ - err = rows.Scan(values...) + if rows.Next() { + db.RowsAffected++ + err = rows.Scan(values...) + } } } - } - if err != nil { - db.AddError(err) - } - } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() + if err != nil { + db.AddError(err) + } } else { - db.AddError(err) + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 582a76f4..451569cf 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -72,12 +72,14 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/callbacks/query.go b/callbacks/query.go index e4e76665..f7c3271f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -23,14 +23,16 @@ func Query(db *gorm.DB) { BuildQuerySQL(db) } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - return - } - defer rows.Close() + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, false) + } } } diff --git a/callbacks/update.go b/callbacks/update.go index fda07676..a52bd310 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -85,12 +85,14 @@ func Update(db *gorm.DB) { return } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index ebdeeab0..6820bb7b 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -61,41 +61,43 @@ func Create(db *gorm.DB) { } } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - defer rows.Close() + if err == nil { + defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + + for rows.Next() { + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } + } + case reflect.Struct: + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - for rows.Next() { for idx, field := range db.Statement.Schema.PrimaryFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - } - case reflect.Struct: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - - for idx, field := range db.Statement.Schema.PrimaryFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } } } + } else { + db.AddError(err) } - } else { - db.AddError(err) } } diff --git a/gorm.go b/gorm.go index c1d6f8da..7d6bd2ed 100644 --- a/gorm.go +++ b/gorm.go @@ -22,6 +22,8 @@ type Config struct { Logger logger.Interface // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -45,6 +47,7 @@ type DB struct { // Session session config when create session with Session() method type Session struct { + DryRun bool WithConditions bool Context context.Context Logger logger.Interface @@ -120,6 +123,10 @@ func (db *DB) Session(config *Session) *DB { tx.clone = 3 } + if config.DryRun { + tx.Config.DryRun = true + } + if config.Logger != nil { tx.Config.Logger = config.Logger } diff --git a/logger/sql.go b/logger/sql.go index dd502324..d3c0bf10 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -22,8 +22,10 @@ func isPrintable(s []byte) bool { var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} -func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) + var vars = make([]interface{}, len(avars)) + copy(vars, avars) convertParams = func(v interface{}, idx int) { switch v := v.(type) { diff --git a/statement.go b/statement.go index aa7d193c..03d1b8a8 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { writer.WriteString("(NULL)") } + case *DB: + result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement + writer.WriteString(result.SQL.String()) + stmt.Vars = append(stmt.Vars, result.Vars...) default: switch rv := reflect.ValueOf(v); rv.Kind() { case reflect.Slice, reflect.Array: @@ -226,7 +230,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con case clause.Expression: conds = append(conds, v) case *DB: - if v.Statement == nil { + if v.Statement != nil { if cs, ok := v.Statement.Clauses["WHERE"]; ok { conds = append(conds, cs.Expression) } @@ -367,7 +371,9 @@ func (stmt *Statement) reinit() { // }) // stmt.Schema = nil - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil + } } diff --git a/tests/query_test.go b/tests/query_test.go index 73b6dca3..12f29ace 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -349,3 +349,89 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline multiple value map") } } + +func TestSubQuery(t *testing.T) { + users := []User{ + {Name: "subquery_1", Age: 10}, + {Name: "subquery_2", Age: 20}, + {Name: "subquery_3", Age: 30}, + {Name: "subquery_4", Age: 40}, + } + + DB.Create(&users) + + if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil { + t.Fatalf("got error: %v", err) + } + + if len(users) != 4 { + t.Errorf("Four users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB. + Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + +func TestSubQueryWithRaw(t *testing.T) { + users := []User{ + {Name: "subquery_raw_1", Age: 10}, + {Name: "subquery_raw_2", Age: 20}, + {Name: "subquery_raw_3", Age: 30}, + {Name: "subquery_raw_4", Age: 40}, + } + DB.Create(&users) + + var count int64 + err := DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_raw%"). + Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } +} + +func TestSubQueryWithHaving(t *testing.T) { + users := []User{ + {Name: "subquery_having_1", Age: 10}, + {Name: "subquery_having_2", Age: 20}, + {Name: "subquery_having_3", Age: 30}, + {Name: "subquery_having_4", Age: 40}, + } + DB.Create(&users) + + var results []User + DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB. + Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results) + + if len(results) != 2 { + t.Errorf("Two user group should be found, instead found %d", len(results)) + } +}