callbacks support sort with wildcard

This commit is contained in:
Jinzhu 2020-08-03 21:48:36 +08:00
parent f83b00d20d
commit c11c939b95
4 changed files with 40 additions and 20 deletions

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"sort"
"time"
"gorm.io/gorm/logger"
@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string
sortCallback func(*callback) error
)
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
})
for _, c := range cs {
// show warning message the callback name already exists
@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
}
if c.after != "" { // if defined after callback
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last
sorted = append(sorted, c.name)

View File

@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB {
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
mux: preparedStmt.mux,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}

View File

@ -9,12 +9,12 @@ import (
type PreparedStmtDB struct {
Stmts map[string]*sql.Stmt
PreparedSQL []string
mux sync.RWMutex
Mux sync.RWMutex
ConnPool
}
func (db *PreparedStmtDB) Close() {
db.mux.Lock()
db.Mux.Lock()
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() {
}
}
db.mux.Unlock()
db.Mux.Unlock()
}
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.mux.RLock()
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok {
db.mux.RUnlock()
db.Mux.RUnlock()
return stmt, nil
}
db.mux.RUnlock()
db.Mux.RUnlock()
db.mux.Lock()
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok {
db.mux.Unlock()
db.Mux.Unlock()
return stmt, nil
}
@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.Stmts[query] = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
}
db.mux.Unlock()
db.Mux.Unlock()
return stmt, err
}
@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.mux.Lock()
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
db.Mux.Unlock()
}
}
return result, err
@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.mux.Lock()
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
db.Mux.Unlock()
}
}
return rows, err
@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.mux.Lock()
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
tx.PreparedStmtDB.Mux.Unlock()
}
}
return result, err
@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.mux.Lock()
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
tx.PreparedStmtDB.Mux.Unlock()
}
}
return rows, err

View File

@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) {
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c4", "c3"},
},
}
for idx, data := range datas {