mirror of https://github.com/go-gorm/gorm.git
Add tests for callbacks
This commit is contained in:
parent
c5b0908b22
commit
2600e1099e
|
@ -3,7 +3,6 @@
|
||||||
Yet Another ORM library for Go, aims for developer friendly
|
Yet Another ORM library for Go, aims for developer friendly
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
* After/Before Save/Update/Create/Delete
|
|
||||||
* Soft Delete
|
* Soft Delete
|
||||||
* Better First method (First(&user, primary_key, where conditions))
|
* Better First method (First(&user, primary_key, where conditions))
|
||||||
* Even more complex where query (with map or struct)
|
* Even more complex where query (with map or struct)
|
||||||
|
|
94
orm_test.go
94
orm_test.go
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -27,6 +28,8 @@ type Product struct {
|
||||||
AfterUpdateCallTimes int64
|
AfterUpdateCallTimes int64
|
||||||
BeforeSaveCallTimes int64
|
BeforeSaveCallTimes int64
|
||||||
AfterSaveCallTimes int64
|
AfterSaveCallTimes int64
|
||||||
|
BeforeDeleteCallTimes int64
|
||||||
|
AfterDeleteCallTimes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -329,16 +332,28 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeCreate() {
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
|
if s.Code == "Invalid" {
|
||||||
|
err = errors.New("invalid product")
|
||||||
|
}
|
||||||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeUpdate() {
|
func (s *Product) BeforeUpdate() (err error) {
|
||||||
|
if s.Code == "dont_update" {
|
||||||
|
err = errors.New("Can't update")
|
||||||
|
}
|
||||||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeSave() {
|
func (s *Product) BeforeSave() (err error) {
|
||||||
|
if s.Code == "dont_save" {
|
||||||
|
err = errors.New("Can't save")
|
||||||
|
}
|
||||||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterCreate() {
|
func (s *Product) AfterCreate() {
|
||||||
|
@ -353,30 +368,93 @@ func (s *Product) AfterSave() {
|
||||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeDelete() (err error) {
|
||||||
|
if s.Code == "dont_delete" {
|
||||||
|
err = errors.New("Can't delete")
|
||||||
|
}
|
||||||
|
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterDelete() {
|
||||||
|
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||||
|
}
|
||||||
func (p *Product) GetCallTimes() []int64 {
|
func (p *Product) GetCallTimes() []int64 {
|
||||||
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes}
|
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunCallbacks(t *testing.T) {
|
func TestRunCallbacks(t *testing.T) {
|
||||||
p := Product{Code: "unique_code", Price: 100}
|
p := Product{Code: "unique_code", Price: 100}
|
||||||
db.Save(&p)
|
db.Save(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0}) {
|
||||||
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
|
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Where("Code = ?", "unique_code").First(&p)
|
db.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0, 0, 0}) {
|
||||||
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
|
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Price = 200
|
p.Price = 200
|
||||||
db.Save(&p)
|
db.Save(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1, 0, 0}) {
|
||||||
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Where("Code = ?", "unique_code").First(&p)
|
db.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 0, 0}) {
|
||||||
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Delete(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 1, 1}) {
|
||||||
|
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
||||||
|
t.Errorf("Should get error when find an deleted record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunCallbacksAndGetErrors(t *testing.T) {
|
||||||
|
p := Product{Code: "Invalid", Price: 100}
|
||||||
|
if db.Save(&p).Error == nil {
|
||||||
|
t.Errorf("An error from create callbacks expected when create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
||||||
|
t.Errorf("Should not save records that have errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
||||||
|
t.Errorf("An error from create callbacks expected when create")
|
||||||
|
}
|
||||||
|
|
||||||
|
p2 := Product{Code: "update_callback", Price: 100}
|
||||||
|
db.Save(&p2)
|
||||||
|
p2.Code = "dont_update"
|
||||||
|
if db.Save(&p2).Error == nil {
|
||||||
|
t.Errorf("An error from callbacks expected when update")
|
||||||
|
}
|
||||||
|
if db.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
||||||
|
t.Errorf("Record Should not be updated due to errors happened in callback")
|
||||||
|
}
|
||||||
|
if db.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
||||||
|
t.Errorf("Record Should not be updated due to errors happened in callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
p2.Code = "dont_save"
|
||||||
|
if db.Save(&p2).Error == nil {
|
||||||
|
t.Errorf("An error from before save callbacks expected when update")
|
||||||
|
}
|
||||||
|
|
||||||
|
p3 := Product{Code: "dont_delete", Price: 100}
|
||||||
|
db.Save(&p3)
|
||||||
|
if db.Delete(&p3).Error == nil {
|
||||||
|
t.Errorf("An error from before delete callbacks expected when delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||||
|
t.Errorf("Should not delete record due to errors happened in callback")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
7
sql.go
7
sql.go
|
@ -132,6 +132,8 @@ func (s *Orm) create(value interface{}) {
|
||||||
s.err(s.model.callMethod("BeforeCreate"))
|
s.err(s.model.callMethod("BeforeCreate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
s.explain(value, "Create")
|
s.explain(value, "Create")
|
||||||
|
|
||||||
|
if len(s.Errors) == 0 {
|
||||||
if s.driver == "postgres" {
|
if s.driver == "postgres" {
|
||||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||||
} else {
|
} else {
|
||||||
|
@ -146,6 +148,7 @@ func (s *Orm) create(value interface{}) {
|
||||||
|
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Orm) updateSql(value interface{}) {
|
func (s *Orm) updateSql(value interface{}) {
|
||||||
|
@ -168,7 +171,9 @@ func (s *Orm) updateSql(value interface{}) {
|
||||||
func (s *Orm) update(value interface{}) {
|
func (s *Orm) update(value interface{}) {
|
||||||
s.err(s.model.callMethod("BeforeUpdate"))
|
s.err(s.model.callMethod("BeforeUpdate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
if len(s.Errors) == 0 {
|
||||||
s.explain(value, "Update").Exec()
|
s.explain(value, "Update").Exec()
|
||||||
|
}
|
||||||
s.err(s.model.callMethod("AfterUpdate"))
|
s.err(s.model.callMethod("AfterUpdate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
return
|
return
|
||||||
|
@ -181,7 +186,9 @@ func (s *Orm) deleteSql(value interface{}) {
|
||||||
|
|
||||||
func (s *Orm) delete(value interface{}) {
|
func (s *Orm) delete(value interface{}) {
|
||||||
s.err(s.model.callMethod("BeforeDelete"))
|
s.err(s.model.callMethod("BeforeDelete"))
|
||||||
|
if len(s.Errors) == 0 {
|
||||||
s.Exec()
|
s.Exec()
|
||||||
|
}
|
||||||
s.err(s.model.callMethod("AfterDelete"))
|
s.err(s.model.callMethod("AfterDelete"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue