Improve Hooks

This commit is contained in:
Jinzhu 2020-05-31 23:55:56 +08:00
parent 95a6539331
commit befef0c9a9
16 changed files with 610 additions and 320 deletions

View File

@ -10,7 +10,7 @@ import (
) )
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Belongs To associations // Save Belongs To associations
@ -83,7 +83,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) {
if db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Has One associations // Save Has One associations

View File

@ -9,20 +9,21 @@ import (
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
var ok bool var ok bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true ok = true
i.BeforeSave(db) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeCreate { if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok { if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true ok = true
i.BeforeCreate(db) db.AddError(i.BeforeCreate(tx))
} }
} }
return ok return ok
@ -31,7 +32,7 @@ func BeforeCreate(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:
@ -46,6 +47,7 @@ func Create(config *Config) func(db *gorm.DB) {
return CreateWithReturning return CreateWithReturning
} else { } else {
return func(db *gorm.DB) { return func(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses { for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
@ -95,8 +97,10 @@ func Create(config *Config) func(db *gorm.DB) {
} }
} }
} }
}
func CreateWithReturning(db *gorm.DB) { func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses { for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
@ -170,22 +174,24 @@ func CreateWithReturning(db *gorm.DB) {
} }
} }
} }
}
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
var ok bool var ok bool
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true ok = true
i.AfterSave(db) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterCreate { if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok { if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true ok = true
i.AfterCreate(db) db.AddError(i.AfterCreate(tx))
} }
} }
return ok return ok
@ -194,7 +200,7 @@ func AfterCreate(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:

View File

@ -9,11 +9,12 @@ import (
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete { if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok { if i, ok := value.(gorm.BeforeDeleteInterface); ok {
i.BeforeDelete(db) db.AddError(i.BeforeDelete(tx))
return true return true
} }
} }
@ -23,7 +24,7 @@ func BeforeDelete(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:
@ -34,6 +35,7 @@ func BeforeDelete(db *gorm.DB) {
} }
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.DeleteClauses { for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
@ -78,13 +80,15 @@ func Delete(db *gorm.DB) {
db.AddError(err) db.AddError(err)
} }
} }
}
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete { if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok { if i, ok := value.(gorm.AfterDeleteInterface); ok {
i.AfterDelete(db) db.AddError(i.AfterDelete(tx))
return true return true
} }
} }
@ -94,7 +98,7 @@ func AfterDelete(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:

View File

@ -12,6 +12,7 @@ import (
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses { for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
@ -31,6 +32,7 @@ func Query(db *gorm.DB) {
gorm.Scan(rows, db, false) gorm.Scan(rows, db, false)
} }
}
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
clauseSelect := clause.Select{} clauseSelect := clause.Select{}
@ -129,6 +131,7 @@ func BuildQuerySQL(db *gorm.DB) {
} }
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil {
if len(db.Statement.Preloads) > 0 { if len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{} preloadMap := map[string][]string{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
@ -166,13 +169,15 @@ func Preload(db *gorm.DB) {
} }
} }
} }
}
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind { if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok { if i, ok := value.(gorm.AfterFindInterface); ok {
i.AfterFind(db) db.AddError(i.AfterFind(tx))
return true return true
} }
} }
@ -182,7 +187,7 @@ func AfterQuery(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:

View File

@ -5,6 +5,7 @@ import (
) )
func RawExec(db *gorm.DB) { func RawExec(db *gorm.DB) {
if db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
@ -12,3 +13,4 @@ func RawExec(db *gorm.DB) {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} }
} }
}

View File

@ -5,6 +5,7 @@ import (
) )
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil {
if db.Statement.SQL.String() == "" { if db.Statement.SQL.String() == "" {
BuildQuerySQL(db) BuildQuerySQL(db)
} }
@ -15,3 +16,4 @@ func RowQuery(db *gorm.DB) {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} }
} }
}

View File

@ -1,9 +1,25 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"github.com/jinzhu/gorm"
)
func BeginTransaction(db *gorm.DB) { func BeginTransaction(db *gorm.DB) {
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
tx.InstanceSet("gorm:started_transaction", true)
} else {
tx.Error = nil
}
} }
func CommitOrRollbackTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil {
db.Commit()
} else {
db.Rollback()
}
db.Statement.ConnPool = db.ConnPool
}
} }

View File

@ -10,20 +10,21 @@ import (
) )
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
var ok bool var ok bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true ok = true
i.BeforeSave(db) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeUpdate { if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok { if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true ok = true
i.BeforeUpdate(db) db.AddError(i.BeforeUpdate(tx))
} }
} }
return ok return ok
@ -32,7 +33,7 @@ func BeforeUpdate(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:
@ -43,6 +44,7 @@ func BeforeUpdate(db *gorm.DB) {
} }
func Update(db *gorm.DB) { func Update(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses { for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
@ -72,22 +74,24 @@ func Update(db *gorm.DB) {
db.AddError(err) db.AddError(err)
} }
} }
}
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool { callMethod := func(value interface{}) bool {
var ok bool var ok bool
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true ok = true
i.AfterSave(db) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterUpdate { if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok { if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true ok = true
i.AfterUpdate(db) db.AddError(i.AfterUpdate(tx))
} }
} }
return ok return ok
@ -96,7 +100,7 @@ func AfterUpdate(db *gorm.DB) {
if ok := callMethod(db.Statement.Dest); !ok { if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface()) callMethod(db.Statement.ReflectValue.Index(i).Interface())
} }
case reflect.Struct: case reflect.Struct:

View File

@ -16,7 +16,7 @@ var (
// ErrNotImplemented not implemented // ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented") ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause // ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ErrMissingWhereClause = errors.New("WHERE conditions required")
// ErrUnsupportedRelation unsupported relations // ErrUnsupportedRelation unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations") ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPtrStructSupported only ptr of struct supported // ErrPtrStructSupported only ptr of struct supported

100
gorm.go
View File

@ -40,11 +40,12 @@ type DB struct {
Error error Error error
RowsAffected int64 RowsAffected int64
Statement *Statement Statement *Statement
clone bool clone int
} }
// Session session config when create session with Session() method // Session session config when create session with Session() method
type Session struct { type Session struct {
WithConditions bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
NowFunc func() time.Time NowFunc func() time.Time
@ -76,10 +77,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
config.cacheStore = &sync.Map{} config.cacheStore = &sync.Map{}
} }
db = &DB{ db = &DB{Config: config, clone: 1}
Config: config,
clone: true,
}
db.callbacks = initializeCallbacks(db) db.callbacks = initializeCallbacks(db)
@ -96,38 +94,54 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
// Session create new db session // Session create new db session
func (db *DB) Session(config *Session) *DB { func (db *DB) Session(config *Session) *DB {
var ( var (
tx = db.getInstance() txConfig = *db.Config
stmt = tx.Statement.clone() tx = &DB{
txConfig = *tx.Config Config: &txConfig,
Statement: db.Statement,
clone: 1,
}
) )
if config.Context != nil { if config.Context != nil {
stmt.Context = config.Context if tx.Statement != nil {
tx.Statement = tx.Statement.clone()
} else {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
ConnPool: tx.ConnPool,
}
}
tx.Statement.Context = config.Context
}
if config.WithConditions {
tx.clone = 3
} }
if config.Logger != nil { if config.Logger != nil {
txConfig.Logger = config.Logger tx.Config.Logger = config.Logger
} }
if config.NowFunc != nil { if config.NowFunc != nil {
txConfig.NowFunc = config.NowFunc tx.Config.NowFunc = config.NowFunc
} }
return &DB{ return tx
Config: &txConfig,
Statement: stmt,
clone: true,
}
} }
// WithContext change current instance db's context to ctx // WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB { func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx}) return db.Session(&Session{WithConditions: true, Context: ctx})
} }
// Debug start debug mode // Debug start debug mode
func (db *DB) Debug() (tx *DB) { func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) return db.Session(&Session{
WithConditions: true,
Logger: db.Logger.LogMode(logger.Info),
})
} }
// Set store value with key into current db instance's context // Set store value with key into current db instance's context
@ -145,6 +159,21 @@ func (db *DB) Get(key string) (interface{}, bool) {
return nil, false return nil, false
} }
// InstanceSet store value with key into current db instance's context
func (db *DB) InstanceSet(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
return tx
}
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
if db.Statement != nil {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
return nil, false
}
// Callback returns callback manager // Callback returns callback manager
func (db *DB) Callback() *callbacks { func (db *DB) Callback() *callbacks {
return db.callbacks return db.callbacks
@ -166,18 +195,37 @@ func (db *DB) AddError(err error) error {
} }
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {
if db.clone { if db.clone > 0 {
stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} tx := &DB{Config: db.Config}
switch db.clone {
case 1: // clone with new statement
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
db.clone = 1
tx.Statement = db.Statement
case 3: // with clone statement
if db.Statement != nil { if db.Statement != nil {
stmt.Context = db.Statement.Context tx.Statement = db.Statement.clone()
stmt.ConnPool = db.Statement.ConnPool tx.Statement.DB = tx
} else { }
stmt.Context = context.Background()
stmt.ConnPool = db.ConnPool
} }
return &DB{Config: db.Config, Statement: stmt} if tx.Statement == nil {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
}
}
if db.Statement != nil {
tx.Statement.Context = db.Statement.Context
tx.Statement.ConnPool = db.Statement.ConnPool
} else {
tx.Statement.Context = context.Background()
tx.Statement.ConnPool = db.ConnPool
}
return tx
} }
return db return db

View File

@ -36,37 +36,37 @@ type TxCommiter interface {
} }
type BeforeCreateInterface interface { type BeforeCreateInterface interface {
BeforeCreate(*DB) BeforeCreate(*DB) error
} }
type AfterCreateInterface interface { type AfterCreateInterface interface {
AfterCreate(*DB) AfterCreate(*DB) error
} }
type BeforeUpdateInterface interface { type BeforeUpdateInterface interface {
BeforeUpdate(*DB) BeforeUpdate(*DB) error
} }
type AfterUpdateInterface interface { type AfterUpdateInterface interface {
AfterUpdate(*DB) AfterUpdate(*DB) error
} }
type BeforeSaveInterface interface { type BeforeSaveInterface interface {
BeforeSave(*DB) BeforeSave(*DB) error
} }
type AfterSaveInterface interface { type AfterSaveInterface interface {
AfterSave(*DB) AfterSave(*DB) error
} }
type BeforeDeleteInterface interface { type BeforeDeleteInterface interface {
BeforeDelete(*DB) BeforeDelete(*DB) error
} }
type AfterDeleteInterface interface { type AfterDeleteInterface interface {
AfterDelete(*DB) AfterDelete(*DB) error
} }
type AfterFindInterface interface { type AfterFindInterface interface {
AfterFind(*DB) AfterFind(*DB) error
} }

View File

@ -12,10 +12,12 @@ import (
type UserWithCallback struct { type UserWithCallback struct {
} }
func (UserWithCallback) BeforeSave(*gorm.DB) { func (UserWithCallback) BeforeSave(*gorm.DB) error {
return nil
} }
func (UserWithCallback) AfterCreate(*gorm.DB) { func (UserWithCallback) AfterCreate(*gorm.DB) error {
return nil
} }
func TestCallback(t *testing.T) { func TestCallback(t *testing.T) {

View File

@ -200,12 +200,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
} }
} }
reflectValue := reflect.Indirect(reflect.New(modelType)) reflectValue := reflect.New(modelType)
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks { for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() { switch methodValue.Type().String() {
case "func(*gorm.DB)": // TODO hack case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default: default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)

201
tests/hooks_test.go Normal file
View File

@ -0,0 +1,201 @@
package tests_test
import (
"errors"
"reflect"
"testing"
"github.com/jinzhu/gorm"
. "github.com/jinzhu/gorm/tests"
)
type Product struct {
gorm.Model
Name string
Code string
Price float64
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
}
func (s *Product) BeforeCreate(tx *gorm.DB) (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}
func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}
func (s *Product) BeforeSave(tx *gorm.DB) (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}
func (s *Product) AfterFind(tx *gorm.DB) (err error) {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
return
}
func (s *Product) AfterCreate(tx *gorm.DB) (err error) {
return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error
}
func (s *Product) AfterUpdate(tx *gorm.DB) (err error) {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
return
}
func (s *Product) AfterSave(tx *gorm.DB) (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete(tx *gorm.DB) (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}
func (s *Product) AfterDelete(tx *gorm.DB) (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}
func TestRunCallbacks(t *testing.T) {
DB.Migrator().DropTable(&Product{})
DB.AutoMigrate(&Product{})
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 1 {
t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes)
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Fatalf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
DB.Migrator().DropTable(&Product{})
DB.AutoMigrate(&Product{})
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Fatalf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Fatalf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Fatalf("An error from after create callbacks happened when create with invalid value")
}
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Fatalf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Fatalf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Fatalf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Fatalf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Fatalf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Fatalf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Fatalf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Fatalf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}

View File

@ -59,9 +59,9 @@ func OpenTestConnection() (db *gorm.DB, err error) {
} }
if debug := os.Getenv("DEBUG"); debug == "true" { if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger.LogMode(logger.Info) db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" { } else if debug == "false" {
db.Logger.LogMode(logger.Silent) db.Logger = db.Logger.LogMode(logger.Silent)
} }
return return

View File

@ -14,37 +14,37 @@ func TestTransaction(t *testing.T) {
user := *GetUser("transcation", Config{}) user := *GetUser("transcation", Config{})
if err := tx.Save(&user).Error; err != nil { if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise, but got %v", err) t.Fatalf("No error should raise, but got %v", err)
} }
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record, but got %v", err) t.Fatalf("Should find saved record, but got %v", err)
} }
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx") t.Fatalf("Should return the underlying sql.Tx")
} }
tx.Rollback() tx.Rollback()
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback, but got %v", err) t.Fatalf("Should not find record after rollback, but got %v", err)
} }
tx2 := DB.Begin() tx2 := DB.Begin()
user2 := *GetUser("transcation-2", Config{}) user2 := *GetUser("transcation-2", Config{})
if err := tx2.Save(&user2).Error; err != nil { if err := tx2.Save(&user2).Error; err != nil {
t.Errorf("No error should raise, but got %v", err) t.Fatalf("No error should raise, but got %v", err)
} }
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record, but got %v", err) t.Fatalf("Should find saved record, but got %v", err)
} }
tx2.Commit() tx2.Commit()
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record, but got %v", err) t.Fatalf("Should be able to find committed record, but got %v", err)
} }
} }
@ -52,7 +52,7 @@ func TestTransactionWithBlock(t *testing.T) {
assertPanic := func(f func()) { assertPanic := func(f func()) {
defer func() { defer func() {
if r := recover(); r == nil { if r := recover(); r == nil {
t.Errorf("The code did not panic") t.Fatalf("The code did not panic")
} }
}() }()
f() f()
@ -62,39 +62,39 @@ func TestTransactionWithBlock(t *testing.T) {
err := DB.Transaction(func(tx *gorm.DB) error { err := DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block", Config{}) user := *GetUser("transcation-block", Config{})
if err := tx.Save(&user).Error; err != nil { if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise") t.Fatalf("No error should raise")
} }
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record") t.Fatalf("Should find saved record")
} }
return errors.New("the error message") return errors.New("the error message")
}) })
if err.Error() != "the error message" { if err.Error() != "the error message" {
t.Errorf("Transaction return error will equal the block returns error") t.Fatalf("Transaction return error will equal the block returns error")
} }
if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil {
t.Errorf("Should not find record after rollback") t.Fatalf("Should not find record after rollback")
} }
// commit // commit
DB.Transaction(func(tx *gorm.DB) error { DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block-2", Config{}) user := *GetUser("transcation-block-2", Config{})
if err := tx.Save(&user).Error; err != nil { if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise") t.Fatalf("No error should raise")
} }
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record") t.Fatalf("Should find saved record")
} }
return nil return nil
}) })
if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil {
t.Errorf("Should be able to find committed record") t.Fatalf("Should be able to find committed record")
} }
// panic will rollback // panic will rollback
@ -102,11 +102,11 @@ func TestTransactionWithBlock(t *testing.T) {
DB.Transaction(func(tx *gorm.DB) error { DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block-3", Config{}) user := *GetUser("transcation-block-3", Config{})
if err := tx.Save(&user).Error; err != nil { if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise") t.Fatalf("No error should raise")
} }
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record") t.Fatalf("Should find saved record")
} }
panic("force panic") panic("force panic")
@ -114,7 +114,7 @@ func TestTransactionWithBlock(t *testing.T) {
}) })
if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil {
t.Errorf("Should not find record after panic rollback") t.Fatalf("Should not find record after panic rollback")
} }
} }
@ -122,14 +122,14 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
tx := DB.Begin() tx := DB.Begin()
user := User{Name: "transcation"} user := User{Name: "transcation"}
if err := tx.Save(&user).Error; err != nil { if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise") t.Fatalf("No error should raise")
} }
if err := tx.Commit().Error; err != nil { if err := tx.Commit().Error; err != nil {
t.Errorf("Commit should not raise error") t.Fatalf("Commit should not raise error")
} }
if err := tx.Rollback().Error; err == nil { if err := tx.Rollback().Error; err == nil {
t.Errorf("Rollback after commit should raise error") t.Fatalf("Rollback after commit should raise error")
} }
} }