Add Before/After callbacks

This commit is contained in:
Jinzhu 2020-02-23 21:22:35 +08:00
parent fa22807e12
commit e2a360b9fa
14 changed files with 325 additions and 43 deletions

View File

@ -8,8 +8,36 @@ import (
)
func BeforeCreate(db *gorm.DB) {
// before save
// before create
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true
i.BeforeCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
func SaveBeforeAssociations(db *gorm.DB) {
@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
}
func AfterCreate(db *gorm.DB) {
// after save
// after create
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true
i.AfterCreate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
// ConvertToCreateValues convert to create values

View File

@ -1,12 +1,60 @@
package callbacks
import "github.com/jinzhu/gorm"
import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
i.BeforeDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
func Delete(db *gorm.DB) {
}
func AfterDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
i.AfterDelete(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

View File

@ -1,6 +1,8 @@
package callbacks
import (
"reflect"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)
@ -13,7 +15,7 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.AddError(err)
}
@ -21,5 +23,26 @@ func Preload(db *gorm.DB) {
}
func AfterQuery(db *gorm.DB) {
// after find
if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok {
i.AfterFind(db)
return true
}
}
return false
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

View File

@ -1,12 +1,76 @@
package callbacks
import "github.com/jinzhu/gorm"
import (
"reflect"
"github.com/jinzhu/gorm"
)
func BeforeUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
i.BeforeUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
func Update(db *gorm.DB) {
}
func AfterUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
i.AfterUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

View File

@ -11,7 +11,7 @@ import (
)
func BenchmarkSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
}
func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}

View File

@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
var (
buildNames []string
buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
)

View File

@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {

View File

@ -24,3 +24,39 @@ type CommonDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type BeforeCreateInterface interface {
BeforeCreate(*DB)
}
type AfterCreateInterface interface {
AfterCreate(*DB)
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB)
}
type AfterUpdateInterface interface {
AfterUpdate(*DB)
}
type BeforeSaveInterface interface {
BeforeSave(*DB)
}
type AfterSaveInterface interface {
AfterSave(*DB)
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB)
}
type AfterDeleteInterface interface {
AfterDelete(*DB)
}
type AfterFindInterface interface {
AfterFind(*DB)
}

38
schema/callbacks_test.go Normal file
View File

@ -0,0 +1,38 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
)
type UserWithCallback struct {
}
func (UserWithCallback) BeforeSave(*gorm.DB) {
}
func (UserWithCallback) AfterCreate(*gorm.DB) {
}
func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err)
}
for _, str := range []string{"BeforeSave", "AfterCreate"} {
if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be true", str)
}
}
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be false", str)
}
}
}

View File

@ -15,7 +15,7 @@ type UserCheck struct {
}
func TestParseCheck(t *testing.T) {
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user check, got error %v", err)
}

View File

@ -14,8 +14,8 @@ import (
func TestFieldValuerAndSetter(t *testing.T) {
var (
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{
Model: gorm.Model{
ID: 10,
CreatedAt: time.Now(),
@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
func TestPointerFieldValuerAndSetter(t *testing.T) {
var (
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age uint = 18
active = true
user = User{
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age uint = 18
active = true
user = User{
Model: &gorm.Model{
ID: 10,
CreatedAt: time.Now(),
@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var (
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now())
isAdmin = mybool(false)
user = AdvancedDataTypeUser{
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now())
isAdmin = mybool(false)
user = AdvancedDataTypeUser{
ID: sql.NullInt64{Int64: 10, Valid: true},
Name: &sql.NullString{String: name, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true},

View File

@ -19,7 +19,7 @@ type UserIndex struct {
}
func TestParseIndex(t *testing.T) {
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}

View File

@ -14,20 +14,25 @@ import (
var ErrUnsupportedDataType = errors.New("unsupported data type")
type Schema struct {
Name string
ModelType reflect.Type
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
Fields []*Field
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
Relationships Relationships
err error
namer Namer
cacheStore *sync.Map
Name string
ModelType reflect.Type
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
Fields []*Field
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
Relationships Relationships
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error
namer Namer
cacheStore *sync.Map
}
func (schema Schema) String() string {
@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
}
}
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB)": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
}
}
}
cacheStore.Store(modelType, schema)
// parse relations for unidentified fields

View File

@ -9,7 +9,7 @@ import (
)
func TestParseSchema(t *testing.T) {
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user, got error %v", err)
}
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
}
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}