Add Before/After callbacks

This commit is contained in:
Jinzhu 2020-02-23 21:22:35 +08:00
parent fa22807e12
commit e2a360b9fa
14 changed files with 325 additions and 43 deletions

View File

@ -8,8 +8,36 @@ import (
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
// before save if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
// before create callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true
i.BeforeCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
// after save if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
// after create callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true
i.AfterCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values

View File

@ -1,12 +1,60 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
i.BeforeDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
} }
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
i.AfterDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -1,6 +1,8 @@
package callbacks package callbacks
import ( import (
"reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
) )
@ -13,7 +15,7 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
} }
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.AddError(err) db.AddError(err)
} }
@ -21,5 +23,26 @@ func Preload(db *gorm.DB) {
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
// after find if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok {
i.AfterFind(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -1,12 +1,76 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
i.BeforeUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }
func Update(db *gorm.DB) { func Update(db *gorm.DB) {
} }
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
i.AfterUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
} }

View File

@ -11,7 +11,7 @@ import (
) )
func BenchmarkSelect(b *testing.B) { func BenchmarkSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
} }
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}

View File

@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
var ( var (
buildNames []string buildNames []string
buildNamesMap = map[string]bool{} buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
) )

View File

@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {
for idx, result := range results { for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result { if stmt.SQL.String() != result.Result {

View File

@ -24,3 +24,39 @@ type CommonDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
type BeforeCreateInterface interface {
BeforeCreate(*DB)
}
type AfterCreateInterface interface {
AfterCreate(*DB)
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB)
}
type AfterUpdateInterface interface {
AfterUpdate(*DB)
}
type BeforeSaveInterface interface {
BeforeSave(*DB)
}
type AfterSaveInterface interface {
AfterSave(*DB)
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB)
}
type AfterDeleteInterface interface {
AfterDelete(*DB)
}
type AfterFindInterface interface {
AfterFind(*DB)
}

38
schema/callbacks_test.go Normal file
View File

@ -0,0 +1,38 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
)
type UserWithCallback struct {
}
func (UserWithCallback) BeforeSave(*gorm.DB) {
}
func (UserWithCallback) AfterCreate(*gorm.DB) {
}
func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err)
}
for _, str := range []string{"BeforeSave", "AfterCreate"} {
if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be true", str)
}
}
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be false", str)
}
}
}

View File

@ -15,7 +15,7 @@ type UserCheck struct {
} }
func TestParseCheck(t *testing.T) { func TestParseCheck(t *testing.T) {
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }

View File

@ -14,8 +14,8 @@ import (
func TestFieldValuerAndSetter(t *testing.T) { func TestFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{ user = tests.User{
Model: gorm.Model{ Model: gorm.Model{
ID: 10, ID: 10,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
func TestPointerFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter" name = "pointer_field_valuer_and_setter"
age uint = 18 age uint = 18
active = true active = true
user = User{ user = User{
Model: &gorm.Model{ Model: &gorm.Model{
ID: 10, ID: 10,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter" name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now()) deletedAt = mytime(time.Now())
isAdmin = mybool(false) isAdmin = mybool(false)
user = AdvancedDataTypeUser{ user = AdvancedDataTypeUser{
ID: sql.NullInt64{Int64: 10, Valid: true}, ID: sql.NullInt64{Int64: 10, Valid: true},
Name: &sql.NullString{String: name, Valid: true}, Name: &sql.NullString{String: name, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true},

View File

@ -19,7 +19,7 @@ type UserIndex struct {
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }

View File

@ -14,20 +14,25 @@ import (
var ErrUnsupportedDataType = errors.New("unsupported data type") var ErrUnsupportedDataType = errors.New("unsupported data type")
type Schema struct { type Schema struct {
Name string Name string
ModelType reflect.Type ModelType reflect.Type
Table string Table string
PrioritizedPrimaryField *Field PrioritizedPrimaryField *Field
DBNames []string DBNames []string
PrimaryFields []*Field PrimaryFields []*Field
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
err error BeforeCreate, AfterCreate bool
namer Namer BeforeUpdate, AfterUpdate bool
cacheStore *sync.Map BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error
namer Namer
cacheStore *sync.Map
} }
func (schema Schema) String() string { func (schema Schema) String() string {
@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
} }
} }
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB)": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
}
}
}
cacheStore.Store(modelType, schema) cacheStore.Store(modelType, schema)
// parse relations for unidentified fields // parse relations for unidentified fields

View File

@ -9,7 +9,7 @@ import (
) )
func TestParseSchema(t *testing.T) { func TestParseSchema(t *testing.T) {
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user, got error %v", err) t.Fatalf("failed to parse user, got error %v", err)
} }
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
} }
func TestParseSchemaWithPointerFields(t *testing.T) { func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
} }
func TestParseSchemaWithAdvancedDataType(t *testing.T) { func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }