mirror of https://github.com/go-gorm/gorm.git
Yay, callbacks works
This commit is contained in:
parent
b551fee276
commit
c5b0908b22
6
model.go
6
model.go
|
@ -122,14 +122,16 @@ func (m *Model) TableName() string {
|
|||
return reg.ReplaceAllString(toSnake(t.Name()), "s")
|
||||
}
|
||||
|
||||
func (model *Model) callMethod(method string) error {
|
||||
fm := reflect.ValueOf(model).MethodByName(method)
|
||||
func (m *Model) callMethod(method string) error {
|
||||
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
||||
if fm.IsValid() {
|
||||
v := fm.Call([]reflect.Value{})
|
||||
if len(v) > 0 {
|
||||
if verr, ok := v[0].Interface().(error); ok {
|
||||
return verr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
4
orm.go
4
orm.go
|
@ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm {
|
|||
func (s *Orm) Save(value interface{}) *Orm {
|
||||
s.Model(value)
|
||||
if s.model.PrimaryKeyIsEmpty() {
|
||||
s.explain(value, "Create").create(value)
|
||||
s.create(value)
|
||||
} else {
|
||||
s.explain(value, "Update").update(value)
|
||||
s.update(value)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
|
69
orm_test.go
69
orm_test.go
|
@ -1,6 +1,7 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -14,6 +15,20 @@ type User struct {
|
|||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
Id int64
|
||||
Code string
|
||||
Price int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
BeforeCreateCallTimes int64
|
||||
AfterCreateCallTimes int64
|
||||
BeforeUpdateCallTimes int64
|
||||
AfterUpdateCallTimes int64
|
||||
BeforeSaveCallTimes int64
|
||||
AfterSaveCallTimes int64
|
||||
}
|
||||
|
||||
var (
|
||||
db DB
|
||||
t1, t2, t3, t4, t5 time.Time
|
||||
|
@ -22,11 +37,13 @@ var (
|
|||
func init() {
|
||||
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
db.Exec("drop table users;")
|
||||
db.Exec("drop table products;")
|
||||
|
||||
orm := db.CreateTable(&User{})
|
||||
if orm.Error != nil {
|
||||
panic("No error should raise when create table")
|
||||
}
|
||||
db.CreateTable(&Product{})
|
||||
|
||||
var shortForm = "2006-01-02 15:04:05"
|
||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||
|
@ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) {
|
|||
t.Errorf("Updated At should be changed after update")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Product) BeforeCreate() {
|
||||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) BeforeUpdate() {
|
||||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) BeforeSave() {
|
||||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterCreate() {
|
||||
s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterUpdate() {
|
||||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterSave() {
|
||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||
}
|
||||
|
||||
func (p *Product) GetCallTimes() []int64 {
|
||||
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes}
|
||||
}
|
||||
|
||||
func TestRunCallbacks(t *testing.T) {
|
||||
p := Product{Code: "unique_code", Price: 100}
|
||||
db.Save(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) {
|
||||
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
db.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) {
|
||||
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
p.Price = 200
|
||||
db.Save(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) {
|
||||
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
db.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) {
|
||||
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||
}
|
||||
}
|
||||
|
|
9
sql.go
9
sql.go
|
@ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) {
|
|||
var id int64
|
||||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
s.explain(value, "Create")
|
||||
if s.driver == "postgres" {
|
||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||
} else {
|
||||
|
@ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) {
|
|||
id, err = s.SqlResult.LastInsertId()
|
||||
s.err(err)
|
||||
}
|
||||
result := reflect.ValueOf(s.model.Data).Elem()
|
||||
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
|
||||
result := reflect.ValueOf(s.model.Data).Elem()
|
||||
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
||||
}
|
||||
|
||||
func (s *Orm) updateSql(value interface{}) {
|
||||
|
@ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) {
|
|||
func (s *Orm) update(value interface{}) {
|
||||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
s.Exec()
|
||||
s.explain(value, "Update").Exec()
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue