mirror of https://github.com/go-gorm/gorm.git
callbacks support sort with wildcard
This commit is contained in:
parent
f83b00d20d
commit
c11c939b95
16
callbacks.go
16
callbacks.go
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
|
@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
names, sorted []string
|
||||
sortCallback func(*callback) error
|
||||
)
|
||||
sort.Slice(cs, func(i, j int) bool {
|
||||
return cs[j].before == "*" || cs[j].after == "*"
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
|
@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
|
||||
sortCallback = func(c *callback) error {
|
||||
if c.before != "" { // if defined before callback
|
||||
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||
if c.before == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append([]string{c.name}, sorted...)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
|
@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||
if c.after == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append(sorted, c.name)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if after callback sorted, append current callback to last
|
||||
sorted = append(sorted, c.name)
|
||||
|
|
2
gorm.go
2
gorm.go
|
@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB {
|
|||
preparedStmt := v.(*PreparedStmtDB)
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
mux: preparedStmt.mux,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,12 +9,12 @@ import (
|
|||
type PreparedStmtDB struct {
|
||||
Stmts map[string]*sql.Stmt
|
||||
PreparedSQL []string
|
||||
mux sync.RWMutex
|
||||
Mux sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.mux.Lock()
|
||||
db.Mux.Lock()
|
||||
for _, query := range db.PreparedSQL {
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
delete(db.Stmts, query)
|
||||
|
@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
db.mux.Unlock()
|
||||
db.Mux.Unlock()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||
db.mux.RLock()
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
db.mux.RUnlock()
|
||||
db.Mux.RUnlock()
|
||||
return stmt, nil
|
||||
}
|
||||
db.mux.RUnlock()
|
||||
db.Mux.RUnlock()
|
||||
|
||||
db.mux.Lock()
|
||||
db.Mux.Lock()
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
db.mux.Unlock()
|
||||
db.Mux.Unlock()
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
|||
db.Stmts[query] = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
}
|
||||
db.mux.Unlock()
|
||||
db.Mux.Unlock()
|
||||
|
||||
return stmt, err
|
||||
}
|
||||
|
@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.mux.Lock()
|
||||
db.Mux.Lock()
|
||||
stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
db.mux.Unlock()
|
||||
db.Mux.Unlock()
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
|
@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.mux.Lock()
|
||||
db.Mux.Lock()
|
||||
stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
db.mux.Unlock()
|
||||
db.Mux.Unlock()
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
|
@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||
if err == nil {
|
||||
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.mux.Lock()
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.mux.Unlock()
|
||||
tx.PreparedStmtDB.Mux.Unlock()
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
|
@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||
if err == nil {
|
||||
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.mux.Lock()
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.mux.Unlock()
|
||||
tx.PreparedStmtDB.Mux.Unlock()
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
|
|
|
@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) {
|
|||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
results: []string{"c1", "c4", "c3"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
|
||||
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||
results: []string{"c5", "c1", "c2", "c4", "c3"},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, data := range datas {
|
||||
|
|
Loading…
Reference in New Issue