diff --git a/callbacks.go b/callbacks.go index c917a678..baeb6c09 100644 --- a/callbacks.go +++ b/callbacks.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "sort" "time" "gorm.io/gorm/logger" @@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) + sort.Slice(cs, func(i, j int) bool { + return cs[j].before == "*" || cs[j].after == "*" + }) for _, c := range cs { // show warning message the callback name already exists @@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback = func(c *callback) error { if c.before != "" { // if defined before callback - if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) @@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { } if c.after != "" { // if defined after callback - if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) diff --git a/gorm.go b/gorm.go index c786b5a5..1ace0099 100644 --- a/gorm.go +++ b/gorm.go @@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB { preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, - mux: preparedStmt.mux, + Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 2f4e1d57..7e87558d 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,12 +9,12 @@ import ( type PreparedStmtDB struct { Stmts map[string]*sql.Stmt PreparedSQL []string - mux sync.RWMutex + Mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { - db.mux.Lock() + db.Mux.Lock() for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) @@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() { } } - db.mux.Unlock() + db.Mux.Unlock() } func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { - db.mux.RLock() + db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { - db.mux.RUnlock() + db.Mux.RUnlock() return stmt, nil } - db.mux.RUnlock() + db.Mux.RUnlock() - db.mux.Lock() + db.Mux.Lock() // double check if stmt, ok := db.Stmts[query]; ok { - db.mux.Unlock() + db.Mux.Unlock() return stmt, nil } @@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) } - db.mux.Unlock() + db.Mux.Unlock() return stmt, err } @@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return result, err @@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return rows, err @@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return result, err @@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return rows, err diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 1dbae441..84f56165 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, results: []string{"c1", "c4", "c3"}, }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c4", "c3"}, + }, } for idx, data := range datas {