mirror of https://github.com/go-gorm/gorm.git
callbacks support sort with wildcard
This commit is contained in:
parent
f83b00d20d
commit
c11c939b95
16
callbacks.go
16
callbacks.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||||
names, sorted []string
|
names, sorted []string
|
||||||
sortCallback func(*callback) error
|
sortCallback func(*callback) error
|
||||||
)
|
)
|
||||||
|
sort.Slice(cs, func(i, j int) bool {
|
||||||
|
return cs[j].before == "*" || cs[j].after == "*"
|
||||||
|
})
|
||||||
|
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
|
@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||||
|
|
||||||
sortCallback = func(c *callback) error {
|
sortCallback = func(c *callback) error {
|
||||||
if c.before != "" { // if defined before callback
|
if c.before != "" { // if defined before callback
|
||||||
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
if c.before == "*" && len(sorted) > 0 {
|
||||||
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
|
sorted = append([]string{c.name}, sorted...)
|
||||||
|
}
|
||||||
|
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
// if before callback already sorted, append current callback just after it
|
// if before callback already sorted, append current callback just after it
|
||||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||||
|
@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.after != "" { // if defined after callback
|
if c.after != "" { // if defined after callback
|
||||||
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
if c.after == "*" && len(sorted) > 0 {
|
||||||
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
|
sorted = append(sorted, c.name)
|
||||||
|
}
|
||||||
|
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
// if after callback sorted, append current callback to last
|
// if after callback sorted, append current callback to last
|
||||||
sorted = append(sorted, c.name)
|
sorted = append(sorted, c.name)
|
||||||
|
|
2
gorm.go
2
gorm.go
|
@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
preparedStmt := v.(*PreparedStmtDB)
|
preparedStmt := v.(*PreparedStmtDB)
|
||||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
ConnPool: db.Config.ConnPool,
|
ConnPool: db.Config.ConnPool,
|
||||||
mux: preparedStmt.mux,
|
Mux: preparedStmt.Mux,
|
||||||
Stmts: preparedStmt.Stmts,
|
Stmts: preparedStmt.Stmts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,12 +9,12 @@ import (
|
||||||
type PreparedStmtDB struct {
|
type PreparedStmtDB struct {
|
||||||
Stmts map[string]*sql.Stmt
|
Stmts map[string]*sql.Stmt
|
||||||
PreparedSQL []string
|
PreparedSQL []string
|
||||||
mux sync.RWMutex
|
Mux sync.RWMutex
|
||||||
ConnPool
|
ConnPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) Close() {
|
func (db *PreparedStmtDB) Close() {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
for _, query := range db.PreparedSQL {
|
for _, query := range db.PreparedSQL {
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
delete(db.Stmts, query)
|
delete(db.Stmts, query)
|
||||||
|
@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||||
db.mux.RLock()
|
db.Mux.RLock()
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
db.mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
db.mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
|
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
// double check
|
// double check
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
if stmt, ok := db.Stmts[query]; ok {
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||||
db.Stmts[query] = stmt
|
db.Stmts[query] = stmt
|
||||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||||
}
|
}
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
|
|
||||||
return stmt, err
|
return stmt, err
|
||||||
}
|
}
|
||||||
|
@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = stmt.ExecContext(ctx, args...)
|
result, err = stmt.ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(db.Stmts, query)
|
delete(db.Stmts, query)
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
|
@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = stmt.QueryContext(ctx, args...)
|
rows, err = stmt.QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.mux.Lock()
|
db.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(db.Stmts, query)
|
delete(db.Stmts, query)
|
||||||
db.mux.Unlock()
|
db.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
|
@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(tx.PreparedStmtDB.Stmts, query)
|
delete(tx.PreparedStmtDB.Stmts, query)
|
||||||
tx.PreparedStmtDB.mux.Unlock()
|
tx.PreparedStmtDB.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
|
@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
delete(tx.PreparedStmtDB.Stmts, query)
|
delete(tx.PreparedStmtDB.Stmts, query)
|
||||||
tx.PreparedStmtDB.mux.Unlock()
|
tx.PreparedStmtDB.Mux.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
|
|
|
@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) {
|
||||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||||
results: []string{"c1", "c4", "c3"},
|
results: []string{"c1", "c4", "c3"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c5", "c1", "c2", "c4", "c3"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, data := range datas {
|
for idx, data := range datas {
|
||||||
|
|
Loading…
Reference in New Issue