Add returning tests

This commit is contained in:
Jinzhu 2021-10-28 08:03:23 +08:00
parent 835d7bde59
commit e953880d19
6 changed files with 106 additions and 36 deletions

View File

@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) {
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok { if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode) gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
rows.Close() rows.Close()
} }
} else { } else {
@ -152,6 +155,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
@ -165,7 +169,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
} }
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
}
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {

16
scan.go
View File

@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
if update && db.Statement.Schema != nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
fields := make([]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
}
}
if initialized || rows.Next() {
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
}
}
}
if initialized || rows.Next() { if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns) prepareValues(values, db, columnTypes, columns)

View File

@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
stmt.AddClauseIfNotExists(clause.Update{}) stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build("UPDATE", "SET", "WHERE") stmt.Build(stmt.DB.Callback().Update().Clauses...)
} }
} }

View File

@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
} }
} }
} }
// only sqlite, postgres support returning
func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
users := []*User{
GetUser("delete-returning-1", Config{}),
GetUser("delete-returning-2", Config{}),
GetUser("delete-returning-3", Config{}),
}
DB.Create(&users)
var results []User
DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results)
if len(results) != 2 {
t.Errorf("failed to return delete data, got %v", results)
}
var count int64
DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count)
if count != 1 {
t.Errorf("failed to delete data, current count %v", count)
}
}
func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
companies := []Company{
{Name: "delete-returning-1"},
{Name: "delete-returning-2"},
{Name: "delete-returning-3"},
}
DB.Create(&companies)
var results []Company
DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results)
if len(results) != 2 {
t.Errorf("failed to return delete data, got %v", results)
}
var count int64
DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count)
if count != 1 {
t.Errorf("failed to delete data, current count %v", count)
}
}

View File

@ -7,8 +7,8 @@ require (
github.com/jinzhu/now v1.1.2 github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3 github.com/lib/pq v1.10.3
gorm.io/driver/mysql v1.1.2 gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.2.0 gorm.io/driver/postgres v1.2.1
gorm.io/driver/sqlite v1.2.0 gorm.io/driver/sqlite v1.2.2
gorm.io/driver/sqlserver v1.1.2 gorm.io/driver/sqlserver v1.1.2
gorm.io/gorm v1.22.0 gorm.io/gorm v1.22.0
) )

View File

@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) {
} }
// update with gorm exprs // update with gorm exprs
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
} }
var user4 User var user4 User
DB.First(&user4, user3.ID) DB.First(&user4, user3.ID)
// sqlite, postgres support returning
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
user3.Age += 100 user3.Age += 100
}
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
} }
@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
} }
} }
// only sqlite, postgres support returning
func TestUpdateReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
users := []*User{
GetUser("update-returning-1", Config{}),
GetUser("update-returning-2", Config{}),
GetUser("update-returning-3", Config{}),
}
DB.Create(&users)
var results []User
DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88)
if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 {
t.Errorf("failed to return updated data, got %v", results)
}
if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
if results[1].Age-results[0].Age != 100 {
t.Errorf("failed to return updated age column")
}
}