Yay, callbacks works

This commit is contained in:
Jinzhu 2013-10-27 15:24:01 +08:00
parent b551fee276
commit c5b0908b22
4 changed files with 81 additions and 11 deletions

View File

@ -122,14 +122,16 @@ func (m *Model) TableName() string {
return reg.ReplaceAllString(toSnake(t.Name()), "s")
}
func (model *Model) callMethod(method string) error {
fm := reflect.ValueOf(model).MethodByName(method)
func (m *Model) callMethod(method string) error {
fm := reflect.ValueOf(m.Data).MethodByName(method)
if fm.IsValid() {
v := fm.Call([]reflect.Value{})
if len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok {
return verr
}
}
}
return nil
}

4
orm.go
View File

@ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm {
func (s *Orm) Save(value interface{}) *Orm {
s.Model(value)
if s.model.PrimaryKeyIsEmpty() {
s.explain(value, "Create").create(value)
s.create(value)
} else {
s.explain(value, "Update").update(value)
s.update(value)
}
return s
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"reflect"
"testing"
"time"
)
@ -14,6 +15,20 @@ type User struct {
UpdatedAt time.Time
}
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
}
var (
db DB
t1, t2, t3, t4, t5 time.Time
@ -22,11 +37,13 @@ var (
func init() {
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
db.Exec("drop table users;")
db.Exec("drop table products;")
orm := db.CreateTable(&User{})
if orm.Error != nil {
panic("No error should raise when create table")
}
db.CreateTable(&Product{})
var shortForm = "2006-01-02 15:04:05"
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
@ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) {
t.Errorf("Updated At should be changed after update")
}
}
func (s *Product) BeforeCreate() {
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
}
func (s *Product) BeforeUpdate() {
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
}
func (s *Product) BeforeSave() {
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
}
func (s *Product) AfterCreate() {
s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() {
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
}
func (p *Product) GetCallTimes() []int64 {
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) {
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) {
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
}
p.Price = 200
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) {
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) {
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
}
}

9
sql.go
View File

@ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) {
var id int64
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
s.explain(value, "Create")
if s.driver == "postgres" {
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
} else {
@ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) {
id, err = s.SqlResult.LastInsertId()
s.err(err)
}
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
}
func (s *Orm) updateSql(value interface{}) {
@ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) {
func (s *Orm) update(value interface{}) {
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
s.Exec()
s.explain(value, "Update").Exec()
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
return