package gorm

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"time"

	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

// initializeCallbacks builds the callbacks manager for db, creating one empty
// processor per operation kind ("create", "query", "update", "delete", "row",
// "raw" and "transaction").
func initializeCallbacks(db *DB) *callbacks {
	return &callbacks{
		processors: map[string]*processor{
			"create":      {db: db},
			"query":       {db: db},
			"update":      {db: db},
			"delete":      {db: db},
			"row":         {db: db},
			"raw":         {db: db},
			"transaction": {db: db},
		},
	}
}

// callbacks is the gorm callbacks manager; it holds one processor per
// operation kind, keyed by the operation name.
type callbacks struct {
	processors map[string]*processor
}

// processor holds the callbacks registered for a single operation kind.
type processor struct {
	db        *DB
	Clauses   []string    // default BuildClauses used by Execute when the statement has none
	fns       []func(*DB) // handlers in execution order, rebuilt by compile
	callbacks []*callback // every registration (including remove/replace markers), in registration order
}

// callback describes a single registration request against a processor.
type callback struct {
	name      string
	before    string // name of the callback this one must run before ("*" means before all)
	after     string // name of the callback this one must run after ("*" means after all)
	remove    bool   // marks the named callback as removed
	replace   bool   // marks this registration as an intentional replacement
	match     func(*DB) bool
	handler   func(*DB)
	processor *processor
}

// Create returns the processor registered under "create".
func (cs *callbacks) Create() *processor {
	return cs.processors["create"]
}

// Query returns the processor registered under "query".
func (cs *callbacks) Query() *processor {
	return cs.processors["query"]
}

// Update returns the processor registered under "update".
func (cs *callbacks) Update() *processor {
	return cs.processors["update"]
}

// Delete returns the processor registered under "delete".
func (cs *callbacks) Delete() *processor {
	return cs.processors["delete"]
}

// Row returns the processor registered under "row".
func (cs *callbacks) Row() *processor {
	return cs.processors["row"]
}

// Raw returns the processor registered under "raw".
func (cs *callbacks) Raw() *processor {
	return cs.processors["raw"]
}

// Transaction returns the processor registered under "transaction".
func (cs *callbacks) Transaction() *processor {
	return cs.processors["transaction"]
}

// Begin starts a transaction on tx's connection pool and stores the resulting
// pool back into tx.Statement.ConnPool. If the pool implements neither
// TxBeginner nor ConnPoolBeginner, ErrInvalidTransaction is added to tx.
// Any error is recorded on tx via AddError; tx is always returned.
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
	var err error

	switch beginner := tx.Statement.ConnPool.(type) {
	case TxBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	case ConnPoolBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	default:
		err = ErrInvalidTransaction
	}

	if err != nil {
		// The error is already stored on tx; the returned value is redundant here.
		_ = tx.AddError(err)
	}

	return tx
}

// Execute runs the processor's compiled handlers against db and returns the
// (possibly scope-replaced) db. Before running the handlers it applies pending
// scopes, fills in default build clauses, resolves Model/Dest, parses the
// model and prepares stmt.ReflectValue. Afterwards it traces the generated SQL
// and, unless DryRun is set, clears the statement's SQL and Vars.
func (p *processor) Execute(db *DB) *DB {
	// call scopes; a scope may register further scopes, so loop until none remain
	for len(db.Statement.scopes) > 0 {
		scopes := db.Statement.scopes
		db.Statement.scopes = nil
		for _, scope := range scopes {
			db = scope(db)
		}
	}

	var (
		curTime           = time.Now()
		stmt              = db.Statement
		resetBuildClauses bool
	)

	// fall back to the processor's default clauses, and remember to undo it
	if len(stmt.BuildClauses) == 0 {
		stmt.BuildClauses = p.Clauses
		resetBuildClauses = true
	}

	// assign model values
	if stmt.Model == nil {
		stmt.Model = stmt.Dest
	} else if stmt.Dest == nil {
		stmt.Dest = stmt.Model
	}

	// parse model values; an unsupported data type is tolerated as long as a
	// table, a table expression or raw SQL has been provided
	if stmt.Model != nil {
		if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
			if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
			} else {
				db.AddError(err)
			}
		}
	}

	// assign stmt.ReflectValue
	if stmt.Dest != nil {
		stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
		for stmt.ReflectValue.Kind() == reflect.Ptr {
			// allocate a value for nil pointers that can be set
			if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
				stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
			}

			stmt.ReflectValue = stmt.ReflectValue.Elem()
		}
		if !stmt.ReflectValue.IsValid() {
			db.AddError(ErrInvalidValue)
		}
	}

	for _, f := range p.fns {
		f(db)
	}

	if stmt.SQL.Len() > 0 {
		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
		}, db.Error)
	}

	// in dry-run mode the SQL and Vars are left in place on the statement
	if !stmt.DB.DryRun {
		stmt.SQL.Reset()
		stmt.Vars = nil
	}

	if resetBuildClauses {
		stmt.BuildClauses = nil
	}

	return db
}

// Get returns the handler of the most recently registered, non-removal
// callback entry called name, or nil if there is none.
func (p *processor) Get(name string) func(*DB) {
	for i := len(p.callbacks) - 1; i >= 0; i-- {
		if v := p.callbacks[i]; v.name == name && !v.remove {
			return v.handler
		}
	}
	return nil
}

// Before starts a registration that must run before the callback called name.
func (p *processor) Before(name string) *callback {
	return &callback{before: name, processor: p}
}

// After starts a registration that must run after the callback called name.
func (p *processor) After(name string) *callback {
	return &callback{after: name, processor: p}
}

// Match starts a registration that is only kept when fc returns true for the
// processor's db at compile time.
func (p *processor) Match(fc func(*DB) bool) *callback {
	return &callback{match: fc, processor: p}
}

// Register adds fn as a callback called name, without ordering constraints.
func (p *processor) Register(name string, fn func(*DB)) error {
	return (&callback{processor: p}).Register(name, fn)
}

// Remove marks the callback called name as removed.
func (p *processor) Remove(name string) error {
	return (&callback{processor: p}).Remove(name)
}

// Replace registers fn as the replacement for the callback called name.
func (p *processor) Replace(name string, fn func(*DB)) error {
	return (&callback{processor: p}).Replace(name, fn)
}

// compile drops registrations whose match function rejects the processor's db
// and rebuilds p.fns in sorted order. A sorting error is logged and returned.
func (p *processor) compile() (err error) {
	var matched []*callback
	for _, cb := range p.callbacks {
		if cb.match == nil || cb.match(p.db) {
			matched = append(matched, cb)
		}
	}
	p.callbacks = matched

	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
	}
	return err
}

// Before sets the callback that c must run before and returns c for chaining.
func (c *callback) Before(name string) *callback {
	c.before = name
	return c
}

// After sets the callback that c must run after and returns c for chaining.
func (c *callback) After(name string) *callback {
	c.after = name
	return c
}

// Register names c, attaches fn as its handler, appends it to its processor
// and recompiles the processor.
func (c *callback) Register(name string, fn func(*DB)) error {
	c.name = name
	c.handler = fn
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}

// Remove appends c to its processor as a removal marker for name and
// recompiles the processor. A warning with the caller location is logged.
func (c *callback) Remove(name string) error {
	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name = name
	c.remove = true
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}

// Replace appends c to its processor as a replacement for name with handler
// fn and recompiles the processor. An info message with the caller location
// is logged.
func (c *callback) Replace(name string, fn func(*DB)) error {
	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name = name
	c.handler = fn
	c.replace = true
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}

// getRIndex returns the index of the last occurrence of str in strs, or -1 if
// str is not present.
func getRIndex(strs []string, str string) int {
	for i := len(strs) - 1; i >= 0; i-- {
		if strs[i] == str {
			return i
		}
	}
	return -1
}

// sortCallbacks orders cs according to their before/after constraints and
// returns the handlers in execution order. When several entries share a name,
// the last one in cs decides the handler (and whether it is removed). It
// returns an error when two constraints conflict. cs is reordered in place and
// the before/after fields of its elements may be updated while sorting.
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
	var (
		names, sorted []string
		sortCallback  func(*callback) error
	)

	// Move "*" constraints behind the other callbacks so they are positioned
	// once the rest is known. The sort must be stable: the name lookups below
	// pick the right-most entry, so registration order among the remaining
	// callbacks has to be preserved.
	sort.SliceStable(cs, func(i, j int) bool {
		if cs[j].before == "*" && cs[i].before != "*" {
			return true
		}
		if cs[j].after == "*" && cs[i].after != "*" {
			return true
		}
		return false
	})

	for _, c := range cs {
		// show warning message the callback name already exists
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
		}
		names = append(names, c.name)
	}

	sortCallback = func(c *callback) error {
		if c.before != "" { // if defined before callback
			if c.before == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append([]string{c.name}, sorted...)
				}
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if before callback already sorted, insert current callback just before it
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
				} else if curIdx > sortedIdx {
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
				}
			} else if idx := getRIndex(names, c.before); idx != -1 {
				// if before callback exists but isn't sorted yet, make it run after the current one
				cs[idx].after = c.name
			}
		}

		if c.after != "" { // if defined after callback
			if c.after == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append(sorted, c.name)
				}
			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if after callback sorted, append current callback to last
					sorted = append(sorted, c.name)
				} else if curIdx < sortedIdx {
					return fmt.Errorf("conflicting callback %s with after %s", c.name, c.after)
				}
			} else if idx := getRIndex(names, c.after); idx != -1 {
				// if after callback exists but haven't sorted
				// set after callback's before callback to current callback
				after := cs[idx]

				if after.before == "" {
					after.before = c.name
				}

				if err := sortCallback(after); err != nil {
					return err
				}

				if err := sortCallback(c); err != nil {
					return err
				}
			}
		}

		// if current callback haven't been sorted, append it to last
		if getRIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}

		return nil
	}

	for _, c := range cs {
		if err = sortCallback(c); err != nil {
			return nil, err
		}
	}

	// resolve each sorted name to its last entry; removed ones are skipped
	for _, name := range sorted {
		if idx := getRIndex(names, name); !cs[idx].remove {
			fns = append(fns, cs[idx].handler)
		}
	}

	return fns, nil
}