diff --git a/callbacks/update.go b/callbacks/update.go index 991581dd..1603a517 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) { if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + dest := db.Statement.Dest + db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) + db.Statement.Dest = dest rows.Close() } } else { @@ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - var primaryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + if size := stmt.ReflectValue.Len(); size > 0 { + var primaryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + } } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { diff --git a/scan.go b/scan.go index 70fcda4a..360ed8b9 100644 --- a/scan.go +++ b/scan.go @@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: - if update && db.Statement.Schema != nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - fields := make([]*schema.Field, len(columns)) - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } - } - - if initialized || rows.Next() { - db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) - } - } - } - if initialized || rows.Next() { columnTypes, _ := rows.ColumnTypes() prepareValues(values, db, columnTypes, columns) diff --git a/soft_delete.go b/soft_delete.go index af02f8fd..11c4fafc 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } stmt.AddClauseIfNotExists(clause.Update{}) - stmt.Build("UPDATE", "SET", "WHERE") + stmt.Build(stmt.DB.Callback().Update().Clauses...) } } diff --git a/tests/delete_test.go b/tests/delete_test.go index f62cc606..049b2ac4 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } } + +// only sqlite, postgres support returning +func TestSoftDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("delete-returning-1", Config{}), + GetUser("delete-returning-2", Config{}), + GetUser("delete-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} + +func TestDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + companies := []Company{ + {Name: "delete-returning-1"}, + {Name: "delete-returning-2"}, + {Name: "delete-returning-3"}, + } + DB.Create(&companies) + + var results []Company + DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} diff --git a/tests/go.mod b/tests/go.mod index 6d9e68c1..ab3ef898 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.2.0 - gorm.io/driver/sqlite v1.2.0 + gorm.io/driver/postgres v1.2.1 + gorm.io/driver/sqlite v1.2.2 gorm.io/driver/sqlserver v1.1.2 gorm.io/gorm v1.22.0 ) diff --git a/tests/update_test.go b/tests/update_test.go index f58656ed..14ed9820 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User DB.First(&user4, user3.ID) - // sqlite, postgres support returning - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { - user3.Age += 100 - } + user3.Age += 100 AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") } @@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) { t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) } } + +// only sqlite, postgres support returning +func TestUpdateReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("update-returning-1", Config{}), + GetUser("update-returning-2", Config{}), + GetUser("update-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) + if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { + t.Errorf("failed to return updated data, got %v", results) + } + + if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if results[1].Age-results[0].Age != 100 { + t.Errorf("failed to return updated age column") + } +}