Add Delete test

This commit is contained in:
Jinzhu 2020-03-08 14:51:52 +08:00
parent 2da0ad5bed
commit ce0e6f9f33
5 changed files with 99 additions and 2 deletions

View File

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
@ -32,6 +33,37 @@ func BeforeDelete(db *gorm.DB) {
} }
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Delete{})
values := []reflect.Value{db.Statement.ReflectValue}
if db.Statement.Dest != db.Statement.Model {
values = append(values, reflect.ValueOf(db.Statement.Model))
}
for _, field := range db.Statement.Schema.PrimaryFields {
for _, value := range values {
if value, isZero := field.ValueOf(value); !isZero {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
}
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
} }
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {

View File

@ -135,8 +135,13 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.Dest = value
tx.callbacks.Delete().Execute(tx)
return return
} }

View File

@ -17,6 +17,8 @@ var (
ErrUnaddressable = errors.New("using unaddressable value") ErrUnaddressable = errors.New("using unaddressable value")
// ErrNotImplemented not implemented // ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented") ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
) )
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt

View File

@ -121,7 +121,7 @@ func (l logger) Error(msg string, data ...interface{}) {
// Trace print sql message // Trace print sql message
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
if elapsed := time.Now().Sub(begin); err != nil || (elapsed > l.SlowThreshold && l.SlowThreshold != 0) { if elapsed := time.Now().Sub(begin); elapsed > l.SlowThreshold && l.SlowThreshold != 0 {
sql, rows := fc() sql, rows := fc()
fileline := utils.FileWithLineNum() fileline := utils.FileWithLineNum()
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package tests package tests
import ( import (
"errors"
"reflect" "reflect"
"strconv" "strconv"
"testing" "testing"
@ -18,6 +19,7 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
TestCreate(t, db) TestCreate(t, db)
TestFind(t, db) TestFind(t, db)
TestUpdate(t, db) TestUpdate(t, db)
TestDelete(t, db)
} }
func TestCreate(t *testing.T, db *gorm.DB) { func TestCreate(t *testing.T, db *gorm.DB) {
@ -266,3 +268,59 @@ func TestUpdate(t *testing.T, db *gorm.DB) {
} }
}) })
} }
func TestDelete(t *testing.T, db *gorm.DB) {
db.Migrator().DropTable(&User{})
db.AutoMigrate(&User{})
t.Run("Delete", func(t *testing.T) {
var users = []User{{
Name: "find",
Age: 1,
Birthday: Now(),
}, {
Name: "find",
Age: 2,
Birthday: Now(),
}, {
Name: "find",
Age: 3,
Birthday: Now(),
}}
if err := db.Create(&users).Error; err != nil {
t.Errorf("errors happened when create: %v", err)
}
for _, user := range users {
if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}
}
if err := db.Delete(&users[1]).Error; err != nil {
t.Errorf("errors happened when delete: %v", err)
}
var result User
if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("should returns record not found error, but got %v", err)
}
for _, user := range []User{users[0], users[2]} {
if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil {
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
}
}
if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) {
t.Errorf("should returns missing WHERE clause while deleting error")
}
for _, user := range []User{users[0], users[2]} {
if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil {
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
}
}
})
}