Add returning tests

This commit is contained in:
Jinzhu 2021-10-28 08:03:23 +08:00
parent 835d7bde59
commit e953880d19
6 changed files with 106 additions and 36 deletions

View File

@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) {
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
rows.Close()
}
} else {
@ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {

16
scan.go
View File

@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if update && db.Statement.Schema != nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
fields := make([]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
}
}
if initialized || rows.Next() {
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
}
}
}
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)

View File

@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build("UPDATE", "SET", "WHERE")
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}
}

View File

@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
}
}
}
// only sqlite, postgres support returning
func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
users := []*User{
GetUser("delete-returning-1", Config{}),
GetUser("delete-returning-2", Config{}),
GetUser("delete-returning-3", Config{}),
}
DB.Create(&users)
var results []User
DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results)
if len(results) != 2 {
t.Errorf("failed to return delete data, got %v", results)
}
var count int64
DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count)
if count != 1 {
t.Errorf("failed to delete data, current count %v", count)
}
}
func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
companies := []Company{
{Name: "delete-returning-1"},
{Name: "delete-returning-2"},
{Name: "delete-returning-3"},
}
DB.Create(&companies)
var results []Company
DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results)
if len(results) != 2 {
t.Errorf("failed to return delete data, got %v", results)
}
var count int64
DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count)
if count != 1 {
t.Errorf("failed to delete data, current count %v", count)
}
}

View File

@ -7,8 +7,8 @@ require (
github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3
gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.2.0
gorm.io/driver/sqlite v1.2.0
gorm.io/driver/postgres v1.2.1
gorm.io/driver/sqlite v1.2.2
gorm.io/driver/sqlserver v1.1.2
gorm.io/gorm v1.22.0
)

View File

@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) {
}
// update with gorm exprs
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
var user4 User
DB.First(&user4, user3.ID)
// sqlite, postgres support returning
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
user3.Age += 100
}
user3.Age += 100
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
}
@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) {
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
}
}
// only sqlite, postgres support returning
func TestUpdateReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
users := []*User{
GetUser("update-returning-1", Config{}),
GetUser("update-returning-2", Config{}),
GetUser("update-returning-3", Config{}),
}
DB.Create(&users)
var results []User
DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88)
if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 {
t.Errorf("failed to return updated data, got %v", results)
}
if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
if results[1].Age-results[0].Age != 100 {
t.Errorf("failed to return updated age column")
}
}