Implement callbacks

This commit is contained in:
Jinzhu 2020-01-31 06:35:25 +08:00
parent 9d5b9834d9
commit e509b3100d
6 changed files with 422 additions and 22 deletions

211
callbacks.go Normal file
View File

@ -0,0 +1,211 @@
package gorm
import (
"fmt"
"log"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/utils"
)
// Callbacks gorm callbacks manager
type Callbacks struct {
creates []func(*DB)
queries []func(*DB)
updates []func(*DB)
deletes []func(*DB)
row []func(*DB)
raw []func(*DB)
db *DB
processors []*processor
}
type processor struct {
kind string
name string
before string
after string
remove bool
replace bool
match func(*DB) bool
handler func(*DB)
callbacks *Callbacks
}
func (cs *Callbacks) Create() *processor {
return &processor{callbacks: cs, kind: "create"}
}
func (cs *Callbacks) Query() *processor {
return &processor{callbacks: cs, kind: "query"}
}
func (cs *Callbacks) Update() *processor {
return &processor{callbacks: cs, kind: "update"}
}
func (cs *Callbacks) Delete() *processor {
return &processor{callbacks: cs, kind: "delete"}
}
func (cs *Callbacks) Row() *processor {
return &processor{callbacks: cs, kind: "row"}
}
func (cs *Callbacks) Raw() *processor {
return &processor{callbacks: cs, kind: "raw"}
}
func (p *processor) Before(name string) *processor {
p.before = name
return p
}
func (p *processor) After(name string) *processor {
p.after = name
return p
}
func (p *processor) Match(fc func(*DB) bool) *processor {
p.match = fc
return p
}
func (p *processor) Get(name string) func(*DB) {
for i := len(p.callbacks.processors) - 1; i >= 0; i-- {
if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove {
return v.handler
}
}
return nil
}
func (p *processor) Register(name string, fn func(*DB)) {
p.name = name
p.handler = fn
p.callbacks.processors = append(p.callbacks.processors, p)
p.callbacks.compile(p.callbacks.db)
}
func (p *processor) Remove(name string) {
logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
p.name = name
p.remove = true
p.callbacks.processors = append(p.callbacks.processors, p)
p.callbacks.compile(p.callbacks.db)
}
func (p *processor) Replace(name string, fn func(*DB)) {
logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
p.name = name
p.handler = fn
p.replace = true
p.callbacks.processors = append(p.callbacks.processors, p)
p.callbacks.compile(p.callbacks.db)
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
func sortProcessors(ps []*processor) []func(*DB) {
var (
allNames, sortedNames []string
sortProcessor func(*processor) error
)
for _, p := range ps {
// show warning message the callback name already exists
if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove {
log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum())
}
allNames = append(allNames, p.name)
}
sortProcessor = func(p *processor) error {
if getRIndex(sortedNames, p.name) == -1 { // if not sorted
if p.before != "" { // if defined before callback
if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 {
if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true {
// if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before)
}
} else if idx := getRIndex(allNames, p.before); idx != -1 {
// if before callback exists
ps[idx].after = p.name
}
}
if p.after != "" { // if defined after callback
if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 {
// if after callback sorted, append current callback to last
sortedNames = append(sortedNames, p.name)
} else if idx := getRIndex(allNames, p.after); idx != -1 {
// if after callback exists but haven't sorted
// set after callback's before callback to current callback
if after := ps[idx]; after.before == "" {
after.before = p.name
sortProcessor(after)
}
}
}
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, p.name) == -1 {
sortedNames = append(sortedNames, p.name)
}
}
return nil
}
for _, p := range ps {
sortProcessor(p)
}
var fns []func(*DB)
for _, name := range sortedNames {
if idx := getRIndex(allNames, name); !ps[idx].remove {
fns = append(fns, ps[idx].handler)
}
}
return fns
}
// compile processors
func (cs *Callbacks) compile(db *DB) {
processors := map[string][]*processor{}
for _, p := range cs.processors {
if p.name != "" {
if p.match == nil || p.match(db) {
processors[p.kind] = append(processors[p.kind], p)
}
}
}
for name, ps := range processors {
switch name {
case "create":
cs.creates = sortProcessors(ps)
case "query":
cs.queries = sortProcessors(ps)
case "update":
cs.updates = sortProcessors(ps)
case "delete":
cs.deletes = sortProcessors(ps)
case "row":
cs.row = sortProcessors(ps)
case "raw":
cs.raw = sortProcessors(ps)
}
}
}

131
callbacks_test.go Normal file
View File

@ -0,0 +1,131 @@
package gorm
import (
"fmt"
"reflect"
"runtime"
"strings"
"testing"
)
func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) {
var got []string
for _, f := range funcs {
got = append(got, getFuncName(f))
}
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
}
func getFuncName(fc func(*DB)) string {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".")
return fnames[len(fnames)-1]
}
func c1(*DB) {}
func c2(*DB) {}
func c3(*DB) {}
func c4(*DB) {}
func c5(*DB) {}
func TestCallbacks(t *testing.T) {
type callback struct {
name string
before string
after string
remove bool
replace bool
err error
match func(*DB) bool
h func(*DB)
}
datas := []struct {
callbacks []callback
results []string
}{
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c2", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c5", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c1", "c3", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
}
// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
// var callback2 = &Callback{logger: defaultLogger}
// callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
// callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
// callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
// callback2.Delete().Register("after_create1", afterCreate1)
// callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
// if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
// t.Errorf("register callback with order")
// }
// }
for idx, data := range datas {
callbacks := &Callbacks{}
for _, c := range data.callbacks {
p := callbacks.Create()
if c.name == "" {
c.name = getFuncName(c.h)
}
if c.before != "" {
p = p.Before(c.before)
}
if c.after != "" {
p = p.After(c.after)
}
if c.match != nil {
p = p.Match(c.match)
}
if c.remove {
p.Remove(c.name)
} else if c.replace {
p.Replace(c.name, c.h)
} else {
p.Register(c.name, c.h)
}
}
if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok {
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
}
}
}

View File

@ -1,6 +1,9 @@
package gorm
import "errors"
import (
"errors"
"time"
)
var (
// ErrRecordNotFound record not found error
@ -13,10 +16,14 @@ var (
ErrUnaddressable = errors.New("using unaddressable value")
)
type Error struct {
Err error
}
func (e Error) Unwrap() error {
return e.Err
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

View File

@ -1,7 +1,15 @@
package logger
import (
"fmt"
"log"
"os"
)
type LogLevel int
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)}
const (
Info LogLevel = iota + 1
Warn
@ -11,4 +19,42 @@ const (
// Interface logger interface
type Interface interface {
LogMode(LogLevel) Interface
Info(string, ...interface{})
Warn(string, ...interface{})
Error(string, ...interface{})
}
// Writer log writer interface
type Writer interface {
Print(...interface{})
}
type Logger struct {
Writer
logLevel LogLevel
}
func (logger Logger) LogMode(level LogLevel) Interface {
return Logger{Writer: logger.Writer, logLevel: level}
}
// Info print info
func (logger Logger) Info(msg string, data ...interface{}) {
if logger.logLevel >= Info {
logger.Print("[info] " + fmt.Sprintf(msg, data...))
}
}
// Warn print warn messages
func (logger Logger) Warn(msg string, data ...interface{}) {
if logger.logLevel >= Warn {
logger.Print("[warn] " + fmt.Sprintf(msg, data...))
}
}
// Error print error messages
func (logger Logger) Error(msg string, data ...interface{}) {
if logger.logLevel >= Error {
logger.Print("[error] " + fmt.Sprintf(msg, data...))
}
}

View File

@ -1,15 +0,0 @@
package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

20
utils/utils.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"fmt"
"regexp"
"runtime"
)
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
func FileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}