diff --git a/README.md b/README.md index 801f8e6d..acb0e9f3 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* After/Before Save/Update/Create/Delete * Soft Delete * Better First method (First(&user, primary_key, where conditions)) * Even more complex where query (with map or struct) diff --git a/orm_test.go b/orm_test.go index 984592a6..dcde82a7 100644 --- a/orm_test.go +++ b/orm_test.go @@ -1,6 +1,7 @@ package gorm import ( + "errors" "reflect" "testing" "time" @@ -27,6 +28,8 @@ type Product struct { AfterUpdateCallTimes int64 BeforeSaveCallTimes int64 AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 } var ( @@ -329,16 +332,28 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) { } } -func (s *Product) BeforeCreate() { +func (s *Product) BeforeCreate() (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return } -func (s *Product) BeforeUpdate() { +func (s *Product) BeforeUpdate() (err error) { + if s.Code == "dont_update" { + err = errors.New("Can't update") + } s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return } -func (s *Product) BeforeSave() { +func (s *Product) BeforeSave() (err error) { + if s.Code == "dont_save" { + err = errors.New("Can't save") + } s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return } func (s *Product) AfterCreate() { @@ -353,30 +368,93 @@ func (s *Product) AfterSave() { s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 } +func (s *Product) BeforeDelete() (err error) { + if s.Code == "dont_delete" { + err = errors.New("Can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete() { + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 +} func (p *Product) GetCallTimes() []int64 { - return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes} + return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes} } func TestRunCallbacks(t *testing.T) { p := Product{Code: "unique_code", Price: 100} db.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0}) { t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes()) } db.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0, 0, 0}) { t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes()) } p.Price = 200 db.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1, 0, 0}) { t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) } db.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 0, 0}) { t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) } + + db.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 1, 1}) { + t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) + } + + if db.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Errorf("Should get error when find an deleted record") + } +} + +func TestRunCallbacksAndGetErrors(t *testing.T) { + p := Product{Code: "Invalid", Price: 100} + if db.Save(&p).Error == nil { + t.Errorf("An error from create callbacks expected when create") + } + + if db.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Errorf("Should not save records that have errors") + } + + if db.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Errorf("An error from create callbacks expected when create") + } + + p2 := Product{Code: "update_callback", Price: 100} + db.Save(&p2) + p2.Code = "dont_update" + if db.Save(&p2).Error == nil { + t.Errorf("An error from callbacks expected when update") + } + if db.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Errorf("Record Should not be updated due to errors happened in callback") + } + if db.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Errorf("Record Should not be updated due to errors happened in callback") + } + + p2.Code = "dont_save" + if db.Save(&p2).Error == nil { + t.Errorf("An error from before save callbacks expected when update") + } + + p3 := Product{Code: "dont_delete", Price: 100} + db.Save(&p3) + if db.Delete(&p3).Error == nil { + t.Errorf("An error from before delete callbacks expected when delete") + } + + if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Errorf("Should not delete record due to errors happened in callback") + } } diff --git a/sql.go b/sql.go index eeab9365..a8f24997 100644 --- a/sql.go +++ b/sql.go @@ -132,20 +132,23 @@ func (s *Orm) create(value interface{}) { s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) s.explain(value, "Create") - if s.driver == "postgres" { - s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) - } else { - var err error - s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) - s.err(err) - id, err = s.SqlResult.LastInsertId() - s.err(err) - } - result := reflect.ValueOf(s.model.Data).Elem() - result.FieldByName(s.model.PrimaryKey()).SetInt(id) - s.err(s.model.callMethod("AfterCreate")) - s.err(s.model.callMethod("AfterSave")) + if len(s.Errors) == 0 { + if s.driver == "postgres" { + s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) + } else { + var err error + s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + s.err(err) + id, err = s.SqlResult.LastInsertId() + s.err(err) + } + result := reflect.ValueOf(s.model.Data).Elem() + result.FieldByName(s.model.PrimaryKey()).SetInt(id) + + s.err(s.model.callMethod("AfterCreate")) + s.err(s.model.callMethod("AfterSave")) + } } func (s *Orm) updateSql(value interface{}) { @@ -168,7 +171,9 @@ func (s *Orm) updateSql(value interface{}) { func (s *Orm) update(value interface{}) { s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) - s.explain(value, "Update").Exec() + if len(s.Errors) == 0 { + s.explain(value, "Update").Exec() + } s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterSave")) return @@ -181,7 +186,9 @@ func (s *Orm) deleteSql(value interface{}) { func (s *Orm) delete(value interface{}) { s.err(s.model.callMethod("BeforeDelete")) - s.Exec() + if len(s.Errors) == 0 { + s.Exec() + } s.err(s.model.callMethod("AfterDelete")) }