forked from mirror/gorm
346 lines
8.3 KiB
Go
346 lines
8.3 KiB
Go
package gorm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"time"
|
|
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
func initializeCallbacks(db *DB) *callbacks {
|
|
return &callbacks{
|
|
processors: map[string]*processor{
|
|
"create": {db: db},
|
|
"query": {db: db},
|
|
"update": {db: db},
|
|
"delete": {db: db},
|
|
"row": {db: db},
|
|
"raw": {db: db},
|
|
},
|
|
}
|
|
}
|
|
|
|
// callbacks gorm callbacks manager
|
|
type callbacks struct {
|
|
processors map[string]*processor
|
|
}
|
|
|
|
type processor struct {
|
|
db *DB
|
|
Clauses []string
|
|
fns []func(*DB)
|
|
callbacks []*callback
|
|
}
|
|
|
|
type callback struct {
|
|
name string
|
|
before string
|
|
after string
|
|
remove bool
|
|
replace bool
|
|
match func(*DB) bool
|
|
handler func(*DB)
|
|
processor *processor
|
|
}
|
|
|
|
func (cs *callbacks) Create() *processor {
|
|
return cs.processors["create"]
|
|
}
|
|
|
|
func (cs *callbacks) Query() *processor {
|
|
return cs.processors["query"]
|
|
}
|
|
|
|
func (cs *callbacks) Update() *processor {
|
|
return cs.processors["update"]
|
|
}
|
|
|
|
func (cs *callbacks) Delete() *processor {
|
|
return cs.processors["delete"]
|
|
}
|
|
|
|
func (cs *callbacks) Row() *processor {
|
|
return cs.processors["row"]
|
|
}
|
|
|
|
func (cs *callbacks) Raw() *processor {
|
|
return cs.processors["raw"]
|
|
}
|
|
|
|
func (p *processor) Execute(db *DB) *DB {
|
|
// call scopes
|
|
for len(db.Statement.scopes) > 0 {
|
|
scopes := db.Statement.scopes
|
|
db.Statement.scopes = nil
|
|
for _, scope := range scopes {
|
|
db = scope(db)
|
|
}
|
|
}
|
|
|
|
var (
|
|
curTime = time.Now()
|
|
stmt = db.Statement
|
|
resetBuildClauses bool
|
|
)
|
|
|
|
if len(stmt.BuildClauses) == 0 {
|
|
stmt.BuildClauses = p.Clauses
|
|
resetBuildClauses = true
|
|
}
|
|
|
|
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
|
|
optimizer.ModifyStatement(stmt)
|
|
}
|
|
|
|
// assign model values
|
|
if stmt.Model == nil {
|
|
stmt.Model = stmt.Dest
|
|
} else if stmt.Dest == nil {
|
|
stmt.Dest = stmt.Model
|
|
}
|
|
|
|
// parse model values
|
|
if stmt.Model != nil {
|
|
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
|
|
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
|
|
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
|
} else {
|
|
db.AddError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// assign stmt.ReflectValue
|
|
if stmt.Dest != nil {
|
|
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
|
for stmt.ReflectValue.Kind() == reflect.Ptr {
|
|
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
|
|
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
|
|
}
|
|
|
|
stmt.ReflectValue = stmt.ReflectValue.Elem()
|
|
}
|
|
if !stmt.ReflectValue.IsValid() {
|
|
db.AddError(ErrInvalidValue)
|
|
}
|
|
}
|
|
|
|
for _, f := range p.fns {
|
|
f(db)
|
|
}
|
|
|
|
if stmt.SQL.Len() > 0 {
|
|
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
|
sql, vars := stmt.SQL.String(), stmt.Vars
|
|
if filter, ok := db.Logger.(ParamsFilter); ok {
|
|
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
|
}
|
|
return db.Dialector.Explain(sql, vars...), db.RowsAffected
|
|
}, db.Error)
|
|
}
|
|
|
|
if !stmt.DB.DryRun {
|
|
stmt.SQL.Reset()
|
|
stmt.Vars = nil
|
|
}
|
|
|
|
if resetBuildClauses {
|
|
stmt.BuildClauses = nil
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
func (p *processor) Get(name string) func(*DB) {
|
|
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
|
if v := p.callbacks[i]; v.name == name && !v.remove {
|
|
return v.handler
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *processor) Before(name string) *callback {
|
|
return &callback{before: name, processor: p}
|
|
}
|
|
|
|
func (p *processor) After(name string) *callback {
|
|
return &callback{after: name, processor: p}
|
|
}
|
|
|
|
func (p *processor) Match(fc func(*DB) bool) *callback {
|
|
return &callback{match: fc, processor: p}
|
|
}
|
|
|
|
func (p *processor) Register(name string, fn func(*DB)) error {
|
|
return (&callback{processor: p}).Register(name, fn)
|
|
}
|
|
|
|
func (p *processor) Remove(name string) error {
|
|
return (&callback{processor: p}).Remove(name)
|
|
}
|
|
|
|
func (p *processor) Replace(name string, fn func(*DB)) error {
|
|
return (&callback{processor: p}).Replace(name, fn)
|
|
}
|
|
|
|
func (p *processor) compile() (err error) {
|
|
var callbacks []*callback
|
|
for _, callback := range p.callbacks {
|
|
if callback.match == nil || callback.match(p.db) {
|
|
callbacks = append(callbacks, callback)
|
|
}
|
|
}
|
|
p.callbacks = callbacks
|
|
|
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
|
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *callback) Before(name string) *callback {
|
|
c.before = name
|
|
return c
|
|
}
|
|
|
|
func (c *callback) After(name string) *callback {
|
|
c.after = name
|
|
return c
|
|
}
|
|
|
|
func (c *callback) Register(name string, fn func(*DB)) error {
|
|
c.name = name
|
|
c.handler = fn
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
func (c *callback) Remove(name string) error {
|
|
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
|
c.name = name
|
|
c.remove = true
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
func (c *callback) Replace(name string, fn func(*DB)) error {
|
|
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
|
c.name = name
|
|
c.handler = fn
|
|
c.replace = true
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
// getRIndex returns the index of the last element of strs that equals str,
// or -1 when str does not occur in strs.
func getRIndex(strs []string, str string) int {
	// Walk from the right so the first hit is the rightmost occurrence.
	for pos := len(strs); pos > 0; pos-- {
		if strs[pos-1] == str {
			return pos - 1
		}
	}
	return -1
}
|
|
|
|
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|
var (
|
|
names, sorted []string
|
|
sortCallback func(*callback) error
|
|
)
|
|
sort.Slice(cs, func(i, j int) bool {
|
|
if cs[j].before == "*" && cs[i].before != "*" {
|
|
return true
|
|
}
|
|
if cs[j].after == "*" && cs[i].after != "*" {
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
|
|
for _, c := range cs {
|
|
// show warning message the callback name already exists
|
|
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
|
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
|
|
}
|
|
names = append(names, c.name)
|
|
}
|
|
|
|
sortCallback = func(c *callback) error {
|
|
if c.before != "" { // if defined before callback
|
|
if c.before == "*" && len(sorted) > 0 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
sorted = append([]string{c.name}, sorted...)
|
|
}
|
|
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
// if before callback already sorted, append current callback just after it
|
|
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
|
} else if curIdx > sortedIdx {
|
|
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
|
|
}
|
|
} else if idx := getRIndex(names, c.before); idx != -1 {
|
|
// if before callback exists
|
|
cs[idx].after = c.name
|
|
}
|
|
}
|
|
|
|
if c.after != "" { // if defined after callback
|
|
if c.after == "*" && len(sorted) > 0 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
sorted = append(sorted, c.name)
|
|
}
|
|
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
// if after callback sorted, append current callback to last
|
|
sorted = append(sorted, c.name)
|
|
} else if curIdx < sortedIdx {
|
|
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
|
|
}
|
|
} else if idx := getRIndex(names, c.after); idx != -1 {
|
|
// if after callback exists but haven't sorted
|
|
// set after callback's before callback to current callback
|
|
after := cs[idx]
|
|
|
|
if after.before == "" {
|
|
after.before = c.name
|
|
}
|
|
|
|
if err := sortCallback(after); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := sortCallback(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// if current callback haven't been sorted, append it to last
|
|
if getRIndex(sorted, c.name) == -1 {
|
|
sorted = append(sorted, c.name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, c := range cs {
|
|
if err = sortCallback(c); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
for _, name := range sorted {
|
|
if idx := getRIndex(names, name); !cs[idx].remove {
|
|
fns = append(fns, cs[idx].handler)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|