forked from mirror/gorm
Add returning support to delete
This commit is contained in:
parent
af3fbdc2fc
commit
835d7bde59
|
@ -57,7 +57,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||||
deleteCallback.Register("gorm:delete", Delete)
|
deleteCallback.Register("gorm:delete", Delete(config))
|
||||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
deleteCallback.Clauses = config.DeleteClauses
|
deleteCallback.Clauses = config.DeleteClauses
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
|
@ -31,18 +32,12 @@ func BeforeCreate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(config *Config) func(db *gorm.DB) {
|
func Create(config *Config) func(db *gorm.DB) {
|
||||||
withReturning := false
|
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
||||||
for _, clause := range config.CreateClauses {
|
|
||||||
if clause == "RETURNING" {
|
|
||||||
withReturning = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(db *gorm.DB) {
|
return func(db *gorm.DB) {
|
||||||
if db.Error != nil {
|
if db.Error != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
onReturning := false
|
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
if !db.Statement.Unscoped {
|
if !db.Statement.Unscoped {
|
||||||
|
@ -51,8 +46,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||||
onReturning = true
|
|
||||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||||
|
@ -72,18 +66,15 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
if onReturning {
|
|
||||||
doNothing := false
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||||
onConflict, _ := c.Expression.(clause.OnConflict)
|
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
|
||||||
doNothing = onConflict.DoNothing
|
mode |= gorm.ScanOnConflictDoNothing
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
if doNothing {
|
gorm.Scan(rows, db, mode)
|
||||||
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
|
|
||||||
} else {
|
|
||||||
gorm.Scan(rows, db, gorm.ScanUpdate)
|
|
||||||
}
|
|
||||||
rows.Close()
|
rows.Close()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeDelete(db *gorm.DB) {
|
func BeforeDelete(db *gorm.DB) {
|
||||||
|
@ -104,8 +105,14 @@ func DeleteBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Delete(db *gorm.DB) {
|
func Delete(config *Config) func(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
|
||||||
|
|
||||||
|
return func(db *gorm.DB) {
|
||||||
|
if db.Error != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||||
db.Statement.AddClause(c)
|
db.Statement.AddClause(c)
|
||||||
|
@ -144,12 +151,16 @@ func Delete(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||||
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
if err == nil {
|
gorm.Scan(rows, db, mode)
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
rows.Close()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
if db.AddError(err) == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,3 +93,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
||||||
|
if supportReturning {
|
||||||
|
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
|
||||||
|
returning, _ := c.Expression.(clause.Returning)
|
||||||
|
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
return true, gorm.ScanUpdate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupUpdateReflectValue(db *gorm.DB) {
|
func SetupUpdateReflectValue(db *gorm.DB) {
|
||||||
|
@ -51,12 +52,7 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(config *Config) func(db *gorm.DB) {
|
func Update(config *Config) func(db *gorm.DB) {
|
||||||
withReturning := false
|
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
||||||
for _, clause := range config.UpdateClauses {
|
|
||||||
if clause == "RETURNING" {
|
|
||||||
withReturning = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(db *gorm.DB) {
|
return func(db *gorm.DB) {
|
||||||
if db.Error != nil {
|
if db.Error != nil {
|
||||||
|
@ -86,18 +82,16 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
gorm.Scan(rows, db, gorm.ScanUpdate)
|
gorm.Scan(rows, db, mode)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,16 @@ func (returning Returning) Name() string {
|
||||||
|
|
||||||
// Build build where clause
|
// Build build where clause
|
||||||
func (returning Returning) Build(builder Builder) {
|
func (returning Returning) Build(builder Builder) {
|
||||||
for idx, column := range returning.Columns {
|
if len(returning.Columns) > 0 {
|
||||||
if idx > 0 {
|
for idx, column := range returning.Columns {
|
||||||
builder.WriteByte(',')
|
if idx > 0 {
|
||||||
}
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
|
||||||
builder.WriteQuoted(column)
|
builder.WriteQuoted(column)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
builder.WriteByte('*')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
scan.go
2
scan.go
|
@ -241,7 +241,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var elem reflect.Value
|
var elem reflect.Value
|
||||||
|
|
||||||
if !update {
|
if !update && reflectValue.Len() != 0 {
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,8 @@ require (
|
||||||
gorm.io/driver/mysql v1.1.2
|
gorm.io/driver/mysql v1.1.2
|
||||||
gorm.io/driver/postgres v1.2.0
|
gorm.io/driver/postgres v1.2.0
|
||||||
gorm.io/driver/sqlite v1.2.0
|
gorm.io/driver/sqlite v1.2.0
|
||||||
gorm.io/driver/sqlserver v1.1.1
|
gorm.io/driver/sqlserver v1.1.2
|
||||||
gorm.io/gorm v1.21.16
|
gorm.io/gorm v1.22.0
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
|
@ -167,7 +167,7 @@ func TestUpdates(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// update with gorm exprs
|
// update with gorm exprs
|
||||||
if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
}
|
}
|
||||||
var user4 User
|
var user4 User
|
||||||
|
|
|
@ -72,6 +72,15 @@ func ToStringKey(values ...interface{}) string {
|
||||||
return strings.Join(results, "_")
|
return strings.Join(results, "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Contains(elems []string, elem string) bool {
|
||||||
|
for _, e := range elems {
|
||||||
|
if elem == e {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func AssertEqual(src, dst interface{}) bool {
|
func AssertEqual(src, dst interface{}) bool {
|
||||||
if !reflect.DeepEqual(src, dst) {
|
if !reflect.DeepEqual(src, dst) {
|
||||||
if valuer, ok := src.(driver.Valuer); ok {
|
if valuer, ok := src.(driver.Valuer); ok {
|
||||||
|
|
Loading…
Reference in New Issue