diff --git a/callbacks/delete.go b/callbacks/delete.go index d79f88fc..05d00d0a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeDelete(db *gorm.DB) { @@ -32,6 +33,37 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + values := []reflect.Value{db.Statement.ReflectValue} + if db.Statement.Dest != db.Statement.Model { + values = append(values, reflect.ValueOf(db.Statement.Model)) + } + for _, field := range db.Statement.Schema.PrimaryFields { + for _, value := range values { + if value, isZero := field.ValueOf(value); !isZero { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build("DELETE", "FROM", "WHERE") + } + + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func AfterDelete(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 0b729cc9..806c6723 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -135,8 +135,13 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } + tx.Statement.Dest = value + tx.callbacks.Delete().Execute(tx) return } diff --git a/helpers.go b/helpers.go index d7177ba7..241d3fbd 100644 --- a/helpers.go +++ b/helpers.go @@ -17,6 +17,8 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ) // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt diff --git a/logger/logger.go b/logger/logger.go index 2a765628..80ae31b1 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -121,7 +121,7 @@ func (l logger) Error(msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { - if elapsed := time.Now().Sub(begin); err != nil || (elapsed > l.SlowThreshold && l.SlowThreshold != 0) { + if elapsed := time.Now().Sub(begin); elapsed > l.SlowThreshold && l.SlowThreshold != 0 { sql, rows := fc() fileline := utils.FileWithLineNum() if err != nil { diff --git a/tests/tests.go b/tests/tests.go index 4181ad46..a15a9d0d 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,6 +1,7 @@ package tests import ( + "errors" "reflect" "strconv" "testing" @@ -18,6 +19,7 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) TestFind(t, db) TestUpdate(t, db) + TestDelete(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -266,3 +268,59 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } }) } + +func TestDelete(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Delete", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := db.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + }) +}