Yay, callbacks works

This commit is contained in:
Jinzhu 2013-10-27 15:24:01 +08:00
parent b551fee276
commit c5b0908b22
4 changed files with 81 additions and 11 deletions

View File

@ -122,14 +122,16 @@ func (m *Model) TableName() string {
return reg.ReplaceAllString(toSnake(t.Name()), "s") return reg.ReplaceAllString(toSnake(t.Name()), "s")
} }
func (model *Model) callMethod(method string) error { func (m *Model) callMethod(method string) error {
fm := reflect.ValueOf(model).MethodByName(method) fm := reflect.ValueOf(m.Data).MethodByName(method)
if fm.IsValid() { if fm.IsValid() {
v := fm.Call([]reflect.Value{}) v := fm.Call([]reflect.Value{})
if len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok { if verr, ok := v[0].Interface().(error); ok {
return verr return verr
} }
} }
}
return nil return nil
} }

4
orm.go
View File

@ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm {
func (s *Orm) Save(value interface{}) *Orm { func (s *Orm) Save(value interface{}) *Orm {
s.Model(value) s.Model(value)
if s.model.PrimaryKeyIsEmpty() { if s.model.PrimaryKeyIsEmpty() {
s.explain(value, "Create").create(value) s.create(value)
} else { } else {
s.explain(value, "Update").update(value) s.update(value)
} }
return s return s
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"reflect"
"testing" "testing"
"time" "time"
) )
@ -14,6 +15,20 @@ type User struct {
UpdatedAt time.Time UpdatedAt time.Time
} }
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
}
var ( var (
db DB db DB
t1, t2, t3, t4, t5 time.Time t1, t2, t3, t4, t5 time.Time
@ -22,11 +37,13 @@ var (
func init() { func init() {
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
db.Exec("drop table users;") db.Exec("drop table users;")
db.Exec("drop table products;")
orm := db.CreateTable(&User{}) orm := db.CreateTable(&User{})
if orm.Error != nil { if orm.Error != nil {
panic("No error should raise when create table") panic("No error should raise when create table")
} }
db.CreateTable(&Product{})
var shortForm = "2006-01-02 15:04:05" var shortForm = "2006-01-02 15:04:05"
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
@ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) {
t.Errorf("Updated At should be changed after update") t.Errorf("Updated At should be changed after update")
} }
} }
func (s *Product) BeforeCreate() {
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
}
func (s *Product) BeforeUpdate() {
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
}
func (s *Product) BeforeSave() {
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
}
func (s *Product) AfterCreate() {
s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() {
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
}
func (p *Product) GetCallTimes() []int64 {
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) {
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) {
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
}
p.Price = 200
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) {
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) {
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
}
}

9
sql.go
View File

@ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) {
var id int64 var id int64
s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
s.explain(value, "Create")
if s.driver == "postgres" { if s.driver == "postgres" {
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
} else { } else {
@ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) {
id, err = s.SqlResult.LastInsertId() id, err = s.SqlResult.LastInsertId()
s.err(err) s.err(err)
} }
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
} }
func (s *Orm) updateSql(value interface{}) { func (s *Orm) updateSql(value interface{}) {
@ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) {
func (s *Orm) update(value interface{}) { func (s *Orm) update(value interface{}) {
s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
s.Exec() s.explain(value, "Update").Exec()
s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
return return