Add Before/After callbacks

This commit is contained in:
Jinzhu 2020-02-23 21:22:35 +08:00
parent fa22807e12
commit e2a360b9fa
14 changed files with 325 additions and 43 deletions

View File

@ -8,8 +8,36 @@ import (
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
// before save if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
// before create callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true
i.BeforeCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
// after save if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
// after create callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true
i.AfterCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values

View File

@ -1,12 +1,60 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
i.BeforeDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
} }
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
i.AfterDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -1,6 +1,8 @@
package callbacks package callbacks
import ( import (
"reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
) )
@ -13,7 +15,7 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
} }
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.AddError(err) db.AddError(err)
} }
@ -21,5 +23,26 @@ func Preload(db *gorm.DB) {
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
// after find if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok {
i.AfterFind(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -1,12 +1,76 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
i.BeforeUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func Update(db *gorm.DB) { func Update(db *gorm.DB) {
} }
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
i.AfterUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -11,7 +11,7 @@ import (
) )
func BenchmarkSelect(b *testing.B) { func BenchmarkSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
} }
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}

View File

@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
var ( var (
buildNames []string buildNames []string
buildNamesMap = map[string]bool{} buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
) )

View File

@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {
for idx, result := range results { for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result { if stmt.SQL.String() != result.Result {

View File

@ -24,3 +24,39 @@ type CommonDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
type BeforeCreateInterface interface {
BeforeCreate(*DB)
}
type AfterCreateInterface interface {
AfterCreate(*DB)
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB)
}
type AfterUpdateInterface interface {
AfterUpdate(*DB)
}
type BeforeSaveInterface interface {
BeforeSave(*DB)
}
type AfterSaveInterface interface {
AfterSave(*DB)
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB)
}
type AfterDeleteInterface interface {
AfterDelete(*DB)
}
type AfterFindInterface interface {
AfterFind(*DB)
}

38
schema/callbacks_test.go Normal file
View File

@ -0,0 +1,38 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
)
type UserWithCallback struct {
}
func (UserWithCallback) BeforeSave(*gorm.DB) {
}
func (UserWithCallback) AfterCreate(*gorm.DB) {
}
func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err)
}
for _, str := range []string{"BeforeSave", "AfterCreate"} {
if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be true", str)
}
}
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be false", str)
}
}
}

View File

@ -15,7 +15,7 @@ type UserCheck struct {
} }
func TestParseCheck(t *testing.T) { func TestParseCheck(t *testing.T) {
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }

View File

@ -14,7 +14,7 @@ import (
func TestFieldValuerAndSetter(t *testing.T) { func TestFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{ user = tests.User{
Model: gorm.Model{ Model: gorm.Model{
ID: 10, ID: 10,
@ -81,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
func TestPointerFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter" name = "pointer_field_valuer_and_setter"
age uint = 18 age uint = 18
active = true active = true
@ -151,7 +151,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter" name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now()) deletedAt = mytime(time.Now())
isAdmin = mybool(false) isAdmin = mybool(false)

View File

@ -19,7 +19,7 @@ type UserIndex struct {
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }

View File

@ -25,6 +25,11 @@ type Schema struct {
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error err error
namer Namer namer Namer
cacheStore *sync.Map cacheStore *sync.Map
@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
} }
} }
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB)": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
}
}
}
cacheStore.Store(modelType, schema) cacheStore.Store(modelType, schema)
// parse relations for unidentified fields // parse relations for unidentified fields

View File

@ -9,7 +9,7 @@ import (
) )
func TestParseSchema(t *testing.T) { func TestParseSchema(t *testing.T) {
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user, got error %v", err) t.Fatalf("failed to parse user, got error %v", err)
} }
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
} }
func TestParseSchemaWithPointerFields(t *testing.T) { func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
} }
func TestParseSchemaWithAdvancedDataType(t *testing.T) { func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }