Add AfterFind callback

This commit is contained in:
Jinzhu 2013-12-30 12:46:37 +08:00
parent 613600411b
commit dd77ca6df7
3 changed files with 28 additions and 8 deletions

View File

@ -6,7 +6,7 @@ Yet Another ORM library for Go, aims for developer friendly
* Chainable API * Chainable API
* Relations * Relations
* Callbacks (before/after create/save/update/delete) * Callbacks (before/after create/save/update/delete/find)
* Soft Delete * Soft Delete
* Auto Migration * Auto Migration
* Transaction * Transaction
@ -637,6 +637,13 @@ BeforeDelete
AfterDelete AfterDelete
``` ```
### After Find
```go
// load record/records from database
AfterFind
```
Here is an example: Here is an example:
```go ```go

2
do.go
View File

@ -443,6 +443,8 @@ func (s *Do) query() *Do {
} }
s.err(rows.Scan(values...)) s.err(rows.Scan(values...))
m := &Model{data: dest.Addr().Interface(), do: s}
m.callMethod("AfterFind")
if is_slice { if is_slice {
dest_out.Set(reflect.Append(dest_out, dest)) dest_out.Set(reflect.Append(dest_out, dest))
} }

View File

@ -70,6 +70,7 @@ type Product struct {
Price int64 Price int64
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64 BeforeCreateCallTimes int64
AfterCreateCallTimes int64 AfterCreateCallTimes int64
BeforeUpdateCallTimes int64 BeforeUpdateCallTimes int64
@ -90,7 +91,7 @@ func init() {
switch os.Getenv("GORM_DIALECT") { switch os.Getenv("GORM_DIALECT") {
case "mysql": case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
// CREATE DATABASE 'gorm'; // CREATE DATABASE gorm;
// GRANT ALL ON gorm.* TO 'gorm'@'localhost'; // GRANT ALL ON gorm.* TO 'gorm'@'localhost';
fmt.Println("testing mysql...") fmt.Println("testing mysql...")
db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
@ -614,6 +615,10 @@ func (s *Product) BeforeSave() (err error) {
return return
} }
func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}
func (s *Product) AfterCreate(db *DB) { func (s *Product) AfterCreate(db *DB) {
db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
} }
@ -647,35 +652,41 @@ func (s *Product) AfterDelete() (err error) {
} }
func (p *Product) GetCallTimes() []int64 { func (p *Product) GetCallTimes() []int64 {
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes} return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes, p.AfterFindCallTimes}
} }
func TestRunCallbacks(t *testing.T) { func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100} p := Product{Code: "unique_code", Price: 100}
db.Save(&p) db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
db.Where("Code = ?", "unique_code").First(&p) db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
} }
p.Price = 200 p.Price = 200
db.Save(&p) db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
var products []Product
db.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should works with slice")
}
db.Where("Code = ?", "unique_code").First(&p) db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
} }
db.Delete(&p) db.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
} }