forked from mirror/gorm
Add Before/After callbacks
This commit is contained in:
parent
fa22807e12
commit
e2a360b9fa
|
@ -8,8 +8,36 @@ import (
|
|||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
// before save
|
||||
// before create
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
ok = true
|
||||
i.BeforeSave(db)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
||||
ok = true
|
||||
i.BeforeCreate(db)
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
|
@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
// after save
|
||||
// after create
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
ok = true
|
||||
i.AfterSave(db)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
||||
ok = true
|
||||
i.AfterCreate(db)
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
|
|
|
@ -1,12 +1,60 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.BeforeDelete {
|
||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
||||
i.BeforeDelete(db)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Delete(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterDelete {
|
||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
||||
i.AfterDelete(db)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
@ -13,7 +15,7 @@ func Query(db *gorm.DB) {
|
|||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
|
@ -21,5 +23,26 @@ func Preload(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// after find
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterFind {
|
||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
||||
i.AfterFind(db)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,76 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
ok = true
|
||||
i.BeforeSave(db)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
||||
ok = true
|
||||
i.BeforeUpdate(db)
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Update(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
ok = true
|
||||
i.AfterSave(db)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
||||
ok = true
|
||||
i.AfterUpdate(db)
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
|
@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
|
|
|
@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
|
|||
var (
|
||||
buildNames []string
|
||||
buildNamesMap = map[string]bool{}
|
||||
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {
|
|||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
|
|
|
@ -24,3 +24,39 @@ type CommonDB interface {
|
|||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*DB)
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*DB)
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*DB)
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*DB)
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*DB)
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*DB)
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*DB)
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*DB)
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*DB)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package schema_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
type UserWithCallback struct {
|
||||
}
|
||||
|
||||
func (UserWithCallback) BeforeSave(*gorm.DB) {
|
||||
}
|
||||
|
||||
func (UserWithCallback) AfterCreate(*gorm.DB) {
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user with callback, got error %v", err)
|
||||
}
|
||||
|
||||
for _, str := range []string{"BeforeSave", "AfterCreate"} {
|
||||
if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
|
||||
t.Errorf("%v should be true", str)
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
|
||||
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
|
||||
t.Errorf("%v should be false", str)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -15,7 +15,7 @@ type UserCheck struct {
|
|||
}
|
||||
|
||||
func TestParseCheck(t *testing.T) {
|
||||
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user check, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
func TestFieldValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user = tests.User{
|
||||
Model: gorm.Model{
|
||||
ID: 10,
|
||||
|
@ -81,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
|||
|
||||
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "pointer_field_valuer_and_setter"
|
||||
age uint = 18
|
||||
active = true
|
||||
|
@ -151,7 +151,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
|||
|
||||
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "advanced_data_type_valuer_and_setter"
|
||||
deletedAt = mytime(time.Now())
|
||||
isAdmin = mybool(false)
|
||||
|
|
|
@ -19,7 +19,7 @@ type UserIndex struct {
|
|||
}
|
||||
|
||||
func TestParseIndex(t *testing.T) {
|
||||
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,11 @@ type Schema struct {
|
|||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
BeforeCreate, AfterCreate bool
|
||||
BeforeUpdate, AfterUpdate bool
|
||||
BeforeDelete, AfterDelete bool
|
||||
BeforeSave, AfterSave bool
|
||||
AfterFind bool
|
||||
err error
|
||||
namer Namer
|
||||
cacheStore *sync.Map
|
||||
|
@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
|
|||
}
|
||||
}
|
||||
|
||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||
for _, name := range callbacks {
|
||||
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB)": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
|
||||
default:
|
||||
logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cacheStore.Store(modelType, schema)
|
||||
|
||||
// parse relations for unidentified fields
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func TestParseSchema(t *testing.T) {
|
||||
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user, got error %v", err)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
}
|
||||
|
||||
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue