mirror of https://github.com/go-gorm/gorm.git
Yay, callbacks works
This commit is contained in:
parent
b551fee276
commit
c5b0908b22
10
model.go
10
model.go
|
@ -122,12 +122,14 @@ func (m *Model) TableName() string {
|
||||||
return reg.ReplaceAllString(toSnake(t.Name()), "s")
|
return reg.ReplaceAllString(toSnake(t.Name()), "s")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (model *Model) callMethod(method string) error {
|
func (m *Model) callMethod(method string) error {
|
||||||
fm := reflect.ValueOf(model).MethodByName(method)
|
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
||||||
if fm.IsValid() {
|
if fm.IsValid() {
|
||||||
v := fm.Call([]reflect.Value{})
|
v := fm.Call([]reflect.Value{})
|
||||||
if verr, ok := v[0].Interface().(error); ok {
|
if len(v) > 0 {
|
||||||
return verr
|
if verr, ok := v[0].Interface().(error); ok {
|
||||||
|
return verr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
4
orm.go
4
orm.go
|
@ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm {
|
||||||
func (s *Orm) Save(value interface{}) *Orm {
|
func (s *Orm) Save(value interface{}) *Orm {
|
||||||
s.Model(value)
|
s.Model(value)
|
||||||
if s.model.PrimaryKeyIsEmpty() {
|
if s.model.PrimaryKeyIsEmpty() {
|
||||||
s.explain(value, "Create").create(value)
|
s.create(value)
|
||||||
} else {
|
} else {
|
||||||
s.explain(value, "Update").update(value)
|
s.update(value)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
69
orm_test.go
69
orm_test.go
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -14,6 +15,20 @@ type User struct {
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Product struct {
|
||||||
|
Id int64
|
||||||
|
Code string
|
||||||
|
Price int64
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
BeforeCreateCallTimes int64
|
||||||
|
AfterCreateCallTimes int64
|
||||||
|
BeforeUpdateCallTimes int64
|
||||||
|
AfterUpdateCallTimes int64
|
||||||
|
BeforeSaveCallTimes int64
|
||||||
|
AfterSaveCallTimes int64
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
db DB
|
db DB
|
||||||
t1, t2, t3, t4, t5 time.Time
|
t1, t2, t3, t4, t5 time.Time
|
||||||
|
@ -22,11 +37,13 @@ var (
|
||||||
func init() {
|
func init() {
|
||||||
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
db.Exec("drop table users;")
|
db.Exec("drop table users;")
|
||||||
|
db.Exec("drop table products;")
|
||||||
|
|
||||||
orm := db.CreateTable(&User{})
|
orm := db.CreateTable(&User{})
|
||||||
if orm.Error != nil {
|
if orm.Error != nil {
|
||||||
panic("No error should raise when create table")
|
panic("No error should raise when create table")
|
||||||
}
|
}
|
||||||
|
db.CreateTable(&Product{})
|
||||||
|
|
||||||
var shortForm = "2006-01-02 15:04:05"
|
var shortForm = "2006-01-02 15:04:05"
|
||||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||||
|
@ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) {
|
||||||
t.Errorf("Updated At should be changed after update")
|
t.Errorf("Updated At should be changed after update")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeCreate() {
|
||||||
|
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeUpdate() {
|
||||||
|
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeSave() {
|
||||||
|
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterCreate() {
|
||||||
|
s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterUpdate() {
|
||||||
|
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterSave() {
|
||||||
|
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Product) GetCallTimes() []int64 {
|
||||||
|
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunCallbacks(t *testing.T) {
|
||||||
|
p := Product{Code: "unique_code", Price: 100}
|
||||||
|
db.Save(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) {
|
||||||
|
t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Where("Code = ?", "unique_code").First(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) {
|
||||||
|
t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Price = 200
|
||||||
|
db.Save(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) {
|
||||||
|
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Where("Code = ?", "unique_code").First(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) {
|
||||||
|
t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
9
sql.go
9
sql.go
|
@ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) {
|
||||||
var id int64
|
var id int64
|
||||||
s.err(s.model.callMethod("BeforeCreate"))
|
s.err(s.model.callMethod("BeforeCreate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
s.explain(value, "Create")
|
||||||
if s.driver == "postgres" {
|
if s.driver == "postgres" {
|
||||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||||
} else {
|
} else {
|
||||||
|
@ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) {
|
||||||
id, err = s.SqlResult.LastInsertId()
|
id, err = s.SqlResult.LastInsertId()
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
|
result := reflect.ValueOf(s.model.Data).Elem()
|
||||||
|
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
||||||
|
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
|
|
||||||
result := reflect.ValueOf(s.model.Data).Elem()
|
|
||||||
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Orm) updateSql(value interface{}) {
|
func (s *Orm) updateSql(value interface{}) {
|
||||||
|
@ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) {
|
||||||
func (s *Orm) update(value interface{}) {
|
func (s *Orm) update(value interface{}) {
|
||||||
s.err(s.model.callMethod("BeforeUpdate"))
|
s.err(s.model.callMethod("BeforeUpdate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
s.Exec()
|
s.explain(value, "Update").Exec()
|
||||||
s.err(s.model.callMethod("AfterUpdate"))
|
s.err(s.model.callMethod("AfterUpdate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue