From 835d7bde59a24ac769a1c5ded206b58f7cedfba3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Oct 2021 07:24:38 +0800 Subject: [PATCH] Add returning support to delete --- callbacks/callbacks.go | 2 +- callbacks/create.go | 27 +++++++++------------------ callbacks/delete.go | 25 ++++++++++++++++++------- callbacks/helper.go | 13 +++++++++++++ callbacks/update.go | 16 +++++----------- clause/returning.go | 14 +++++++++----- scan.go | 2 +- tests/go.mod | 4 ++-- tests/update_test.go | 2 +- utils/utils.go | 9 +++++++++ 10 files changed, 68 insertions(+), 46 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index bc18d854..d681aef3 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -57,7 +57,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) - deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:delete", Delete(config)) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Clauses = config.DeleteClauses diff --git a/callbacks/create.go b/callbacks/create.go index fe4cd797..656273fb 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeCreate(db *gorm.DB) { @@ -31,18 +32,12 @@ func BeforeCreate(db *gorm.DB) { } func Create(config *Config) func(db *gorm.DB) { - withReturning := false - for _, clause := range config.CreateClauses { - if clause == "RETURNING" { - withReturning = true - } - } + supportReturning := utils.Contains(config.CreateClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { return } - onReturning := false if db.Statement.Schema != nil { if !db.Statement.Unscoped { @@ -51,8 +46,7 @@ func Create(config *Config) func(db *gorm.DB) { } } - if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - onReturning = true + if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { @@ -72,18 +66,15 @@ func Create(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if onReturning { - doNothing := false + + if ok, mode := hasReturning(db, supportReturning); ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - onConflict, _ := c.Expression.(clause.OnConflict) - doNothing = onConflict.DoNothing + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing + } } if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - if doNothing { - gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) - } else { - gorm.Scan(rows, db, gorm.ScanUpdate) - } + gorm.Scan(rows, db, mode) rows.Close() } } else { diff --git a/callbacks/delete.go b/callbacks/delete.go index 91659c51..a1fd0a57 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeDelete(db *gorm.DB) { @@ -104,8 +105,14 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } -func Delete(db *gorm.DB) { - if db.Error == nil { +func Delete(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + if db.Statement.Schema != nil && !db.Statement.Unscoped { for _, c := range db.Statement.Schema.DeleteClauses { db.Statement.AddClause(c) @@ -144,12 +151,16 @@ func Delete(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() + if ok, mode := hasReturning(db, supportReturning); ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() + } } else { - db.AddError(err) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index d83d20ce..1d96ab26 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -93,3 +93,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st } return } + +func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { + if supportReturning { + if c, ok := tx.Statement.Clauses["RETURNING"]; ok { + returning, _ := c.Expression.(clause.Returning) + if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { + return true, 0 + } + return true, gorm.ScanUpdate + } + } + return false, 0 +} diff --git a/callbacks/update.go b/callbacks/update.go index 90dc6a89..991581dd 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SetupUpdateReflectValue(db *gorm.DB) { @@ -51,12 +52,7 @@ func BeforeUpdate(db *gorm.DB) { } func Update(config *Config) func(db *gorm.DB) { - withReturning := false - for _, clause := range config.UpdateClauses { - if clause == "RETURNING" { - withReturning = true - } - } + supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { @@ -86,18 +82,16 @@ func Update(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { + if ok, mode := hasReturning(db, supportReturning); ok { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, gorm.ScanUpdate) + gorm.Scan(rows, db, mode) rows.Close() } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { + if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } } diff --git a/clause/returning.go b/clause/returning.go index 04bc96da..d94b7a4c 100644 --- a/clause/returning.go +++ b/clause/returning.go @@ -11,12 +11,16 @@ func (returning Returning) Name() string { // Build build where clause func (returning Returning) Build(builder Builder) { - for idx, column := range returning.Columns { - if idx > 0 { - builder.WriteByte(',') - } + if len(returning.Columns) > 0 { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(column) + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') } } diff --git a/scan.go b/scan.go index 37f5112d..70fcda4a 100644 --- a/scan.go +++ b/scan.go @@ -241,7 +241,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { case reflect.Slice, reflect.Array: var elem reflect.Value - if !update { + if !update && reflectValue.Len() != 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) } diff --git a/tests/go.mod b/tests/go.mod index 96db0559..6d9e68c1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.2.0 gorm.io/driver/sqlite v1.2.0 - gorm.io/driver/sqlserver v1.1.1 - gorm.io/gorm v1.21.16 + gorm.io/driver/sqlserver v1.1.2 + gorm.io/gorm v1.22.0 ) replace gorm.io/gorm => ../ diff --git a/tests/update_test.go b/tests/update_test.go index 0dd9465a..f58656ed 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -167,7 +167,7 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User diff --git a/utils/utils.go b/utils/utils.go index 9c238ac5..f00f92ba 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -72,6 +72,15 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} + func AssertEqual(src, dst interface{}) bool { if !reflect.DeepEqual(src, dst) { if valuer, ok := src.(driver.Valuer); ok {