callbacks support sort with wildcard

This commit is contained in:
Jinzhu 2020-08-03 21:48:36 +08:00
parent f83b00d20d
commit c11c939b95
4 changed files with 40 additions and 20 deletions

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"time" "time"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string names, sorted []string
sortCallback func(*callback) error sortCallback func(*callback) error
) )
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
})
for _, c := range cs { for _, c := range cs {
// show warning message the callback name already exists // show warning message the callback name already exists
@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
sortCallback = func(c *callback) error { sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback if c.before != "" { // if defined before callback
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it // if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
} }
if c.after != "" { // if defined after callback if c.after != "" { // if defined after callback
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last // if after callback sorted, append current callback to last
sorted = append(sorted, c.name) sorted = append(sorted, c.name)

View File

@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB {
preparedStmt := v.(*PreparedStmtDB) preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
mux: preparedStmt.mux, Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts, Stmts: preparedStmt.Stmts,
} }
} }

View File

@ -9,12 +9,12 @@ import (
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]*sql.Stmt Stmts map[string]*sql.Stmt
PreparedSQL []string PreparedSQL []string
mux sync.RWMutex Mux sync.RWMutex
ConnPool ConnPool
} }
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.mux.Lock() db.Mux.Lock()
for _, query := range db.PreparedSQL { for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok { if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query) delete(db.Stmts, query)
@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() {
} }
} }
db.mux.Unlock() db.Mux.Unlock()
} }
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok { if stmt, ok := db.Stmts[query]; ok {
db.mux.RUnlock() db.Mux.RUnlock()
return stmt, nil return stmt, nil
} }
db.mux.RUnlock() db.Mux.RUnlock()
db.mux.Lock() db.Mux.Lock()
// double check // double check
if stmt, ok := db.Stmts[query]; ok { if stmt, ok := db.Stmts[query]; ok {
db.mux.Unlock() db.Mux.Unlock()
return stmt, nil return stmt, nil
} }
@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.Stmts[query] = stmt db.Stmts[query] = stmt
db.PreparedSQL = append(db.PreparedSQL, query) db.PreparedSQL = append(db.PreparedSQL, query)
} }
db.mux.Unlock() db.Mux.Unlock()
return stmt, err return stmt, err
} }
@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if err != nil { if err != nil {
db.mux.Lock() db.Mux.Lock()
stmt.Close() stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.mux.Unlock() db.Mux.Unlock()
} }
} }
return result, err return result, err
@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if err != nil { if err != nil {
db.mux.Lock() db.Mux.Lock()
stmt.Close() stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.mux.Unlock() db.Mux.Unlock()
} }
} }
return rows, err return rows, err
@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
if err == nil { if err == nil {
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.mux.Lock() tx.PreparedStmtDB.Mux.Lock()
stmt.Close() stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock() tx.PreparedStmtDB.Mux.Unlock()
} }
} }
return result, err return result, err
@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
if err == nil { if err == nil {
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.mux.Lock() tx.PreparedStmtDB.Mux.Lock()
stmt.Close() stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock() tx.PreparedStmtDB.Mux.Unlock()
} }
} }
return rows, err return rows, err

View File

@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) {
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"}, results: []string{"c1", "c4", "c3"},
}, },
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c4", "c3"},
},
} }
for idx, data := range datas { for idx, data := range datas {