gorm/callbacks.go

312 lines
7.5 KiB
Go
Raw Normal View History

2020-01-31 01:35:25 +03:00
package gorm
import (
2020-05-05 16:28:38 +03:00
"context"
2020-02-02 09:40:44 +03:00
"errors"
2020-01-31 01:35:25 +03:00
"fmt"
2020-02-24 03:51:35 +03:00
"reflect"
2020-08-03 16:48:36 +03:00
"sort"
2020-02-23 07:39:26 +03:00
"time"
2020-01-31 01:35:25 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
2020-01-31 01:35:25 +03:00
)
2020-02-02 14:32:27 +03:00
func initializeCallbacks(db *DB) *callbacks {
2020-01-31 03:29:35 +03:00
return &callbacks{
processors: map[string]*processor{
2020-06-07 10:24:34 +03:00
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
2020-01-31 03:29:35 +03:00
},
}
}
// callbacks gorm callbacks manager
type callbacks struct {
processors map[string]*processor
2020-01-31 01:35:25 +03:00
}
type processor struct {
2020-01-31 03:29:35 +03:00
db *DB
fns []func(*DB)
callbacks []*callback
}
type callback struct {
2020-01-31 01:35:25 +03:00
name string
before string
after string
remove bool
replace bool
match func(*DB) bool
handler func(*DB)
2020-01-31 03:29:35 +03:00
processor *processor
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Create() *processor {
return cs.processors["create"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Update() *processor {
return cs.processors["update"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Delete() *processor {
return cs.processors["delete"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Row() *processor {
return cs.processors["row"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (p *processor) Execute(db *DB) {
2021-02-25 13:49:01 +03:00
var (
curTime = time.Now()
stmt = db.Statement
)
2020-02-04 03:56:15 +03:00
2020-06-08 17:32:35 +03:00
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
2020-06-08 17:32:35 +03:00
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
2020-02-02 09:40:44 +03:00
}
2020-06-08 17:32:35 +03:00
}
2020-05-23 06:57:28 +03:00
2020-06-08 17:32:35 +03:00
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
2021-01-26 15:08:41 +03:00
if stmt.ReflectValue.IsNil() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
break
}
2020-06-08 17:32:35 +03:00
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
2020-05-23 06:57:28 +03:00
}
2020-02-02 09:40:44 +03:00
}
2021-02-25 13:49:01 +03:00
// call scopes
for _, scope := range stmt.scopes {
db = scope(db)
}
stmt.scopes = nil
2020-01-31 03:29:35 +03:00
for _, f := range p.fns {
f(db)
}
2020-02-23 07:39:26 +03:00
2020-06-08 17:32:35 +03:00
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
2020-03-09 10:32:55 +03:00
2020-06-08 17:32:35 +03:00
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
2020-02-23 07:39:26 +03:00
}
2020-01-31 01:35:25 +03:00
}
func (p *processor) Get(name string) func(*DB) {
2020-01-31 03:29:35 +03:00
for i := len(p.callbacks) - 1; i >= 0; i-- {
if v := p.callbacks[i]; v.name == name && !v.remove {
2020-01-31 01:35:25 +03:00
return v.handler
}
}
return nil
}
2020-01-31 03:29:35 +03:00
// Before starts a new callback definition on p that is to be ordered before
// the callback called name. The definition takes effect once Register, Remove
// or Replace is called on the returned value.
func (p *processor) Before(name string) *callback {
	cb := &callback{processor: p}
	cb.before = name
	return cb
}
// After starts a new callback definition on p that is to be ordered after the
// callback called name. The definition takes effect once Register, Remove or
// Replace is called on the returned value.
func (p *processor) After(name string) *callback {
	cb := &callback{processor: p}
	cb.after = name
	return cb
}
// Match starts a new callback definition on p that carries fc as its match
// function. The definition takes effect once Register, Remove or Replace is
// called on the returned value.
func (p *processor) Match(fc func(*DB) bool) *callback {
	cb := &callback{processor: p}
	cb.match = fc
	return cb
}
// Register records fn under name on p, without ordering constraints, and
// returns the error reported by the resulting recompilation.
func (p *processor) Register(name string, fn func(*DB)) error {
	cb := &callback{processor: p}
	return cb.Register(name, fn)
}
// Remove records a removal of the callback called name on p and returns the
// error reported by the resulting recompilation.
func (p *processor) Remove(name string) error {
	cb := &callback{processor: p}
	return cb.Remove(name)
}
// Replace records fn as the replacement for the callback called name on p and
// returns the error reported by the resulting recompilation.
func (p *processor) Replace(name string, fn func(*DB)) error {
	cb := &callback{processor: p}
	return cb.Replace(name, fn)
}
2020-02-02 14:32:27 +03:00
func (p *processor) compile() (err error) {
var callbacks []*callback
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
}
2020-06-02 07:46:55 +03:00
p.callbacks = callbacks
2020-02-02 14:32:27 +03:00
2020-01-31 03:29:35 +03:00
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
2020-01-31 03:29:35 +03:00
}
return
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
func (c *callback) Before(name string) *callback {
c.before = name
return c
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
// After sets the name of the callback that c must run after and returns c so
// that further calls can be chained.
func (c *callback) After(name string) *callback {
	c.after = name
	return c
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
2020-02-02 14:32:27 +03:00
return c.processor.compile()
2020-01-31 03:29:35 +03:00
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
2020-01-31 03:29:35 +03:00
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
2020-02-02 14:32:27 +03:00
return c.processor.compile()
2020-01-31 03:29:35 +03:00
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
2020-01-31 03:29:35 +03:00
c.name = name
c.handler = fn
c.replace = true
c.processor.callbacks = append(c.processor.callbacks, c)
2020-02-02 14:32:27 +03:00
return c.processor.compile()
2020-01-31 01:35:25 +03:00
}
// getRIndex returns the index of the last element of strs that equals str,
// or -1 when str does not occur in strs.
func getRIndex(strs []string, str string) int {
	last := -1
	for i, s := range strs {
		if s == str {
			last = i
		}
	}
	return last
}
2020-01-31 03:29:35 +03:00
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
2020-01-31 01:35:25 +03:00
var (
2020-01-31 03:29:35 +03:00
names, sorted []string
sortCallback func(*callback) error
2020-01-31 01:35:25 +03:00
)
2020-08-03 16:48:36 +03:00
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
})
2020-01-31 01:35:25 +03:00
2020-01-31 03:29:35 +03:00
for _, c := range cs {
2020-01-31 01:35:25 +03:00
// show warning message the callback name already exists
2020-01-31 03:29:35 +03:00
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
names = append(names, c.name)
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback
2020-08-03 16:48:36 +03:00
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
2020-01-31 03:29:35 +03:00
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
cs[idx].after = c.name
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
}
2020-01-31 01:35:25 +03:00
2020-01-31 03:29:35 +03:00
if c.after != "" { // if defined after callback
2020-08-03 16:48:36 +03:00
if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
2020-01-31 03:29:35 +03:00
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
2020-01-31 01:35:25 +03:00
// if after callback sorted, append current callback to last
2020-01-31 03:29:35 +03:00
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted
// set after callback's before callback to current callback
after := cs[idx]
if after.before == "" {
after.before = c.name
}
if err := sortCallback(after); err != nil {
return err
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
if err := sortCallback(c); err != nil {
return err
}
2020-01-31 01:35:25 +03:00
}
}
2020-01-31 03:29:35 +03:00
// if current callback haven't been sorted, append it to last
if getRIndex(sorted, c.name) == -1 {
sorted = append(sorted, c.name)
}
2020-01-31 01:35:25 +03:00
2020-01-31 03:29:35 +03:00
return nil
2020-01-31 01:35:25 +03:00
}
2020-01-31 03:29:35 +03:00
for _, c := range cs {
if err = sortCallback(c); err != nil {
return
2020-01-31 01:35:25 +03:00
}
}
2020-01-31 03:29:35 +03:00
for _, name := range sorted {
if idx := getRIndex(names, name); !cs[idx].remove {
fns = append(fns, cs[idx].handler)
2020-01-31 01:35:25 +03:00
}
}
2020-01-31 03:29:35 +03:00
return
2020-01-31 01:35:25 +03:00
}