mirror of https://github.com/go-gorm/gorm.git
Implement callbacks
This commit is contained in:
parent
9d5b9834d9
commit
e509b3100d
|
@ -0,0 +1,211 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
)
|
||||
|
||||
// Callbacks gorm callbacks manager
|
||||
type Callbacks struct {
|
||||
creates []func(*DB)
|
||||
queries []func(*DB)
|
||||
updates []func(*DB)
|
||||
deletes []func(*DB)
|
||||
row []func(*DB)
|
||||
raw []func(*DB)
|
||||
db *DB
|
||||
processors []*processor
|
||||
}
|
||||
|
||||
type processor struct {
|
||||
kind string
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
remove bool
|
||||
replace bool
|
||||
match func(*DB) bool
|
||||
handler func(*DB)
|
||||
callbacks *Callbacks
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Create() *processor {
|
||||
return &processor{callbacks: cs, kind: "create"}
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Query() *processor {
|
||||
return &processor{callbacks: cs, kind: "query"}
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Update() *processor {
|
||||
return &processor{callbacks: cs, kind: "update"}
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Delete() *processor {
|
||||
return &processor{callbacks: cs, kind: "delete"}
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Row() *processor {
|
||||
return &processor{callbacks: cs, kind: "row"}
|
||||
}
|
||||
|
||||
func (cs *Callbacks) Raw() *processor {
|
||||
return &processor{callbacks: cs, kind: "raw"}
|
||||
}
|
||||
|
||||
func (p *processor) Before(name string) *processor {
|
||||
p.before = name
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *processor) After(name string) *processor {
|
||||
p.after = name
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *processor) Match(fc func(*DB) bool) *processor {
|
||||
p.match = fc
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
for i := len(p.callbacks.processors) - 1; i >= 0; i-- {
|
||||
if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove {
|
||||
return v.handler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *processor) Register(name string, fn func(*DB)) {
|
||||
p.name = name
|
||||
p.handler = fn
|
||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
||||
p.callbacks.compile(p.callbacks.db)
|
||||
}
|
||||
|
||||
func (p *processor) Remove(name string) {
|
||||
logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
p.name = name
|
||||
p.remove = true
|
||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
||||
p.callbacks.compile(p.callbacks.db)
|
||||
}
|
||||
|
||||
func (p *processor) Replace(name string, fn func(*DB)) {
|
||||
logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
p.name = name
|
||||
p.handler = fn
|
||||
p.replace = true
|
||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
||||
p.callbacks.compile(p.callbacks.db)
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func sortProcessors(ps []*processor) []func(*DB) {
|
||||
var (
|
||||
allNames, sortedNames []string
|
||||
sortProcessor func(*processor) error
|
||||
)
|
||||
|
||||
for _, p := range ps {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove {
|
||||
log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum())
|
||||
}
|
||||
allNames = append(allNames, p.name)
|
||||
}
|
||||
|
||||
sortProcessor = func(p *processor) error {
|
||||
if getRIndex(sortedNames, p.name) == -1 { // if not sorted
|
||||
if p.before != "" { // if defined before callback
|
||||
if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...)
|
||||
} else if curIdx > sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before)
|
||||
}
|
||||
} else if idx := getRIndex(allNames, p.before); idx != -1 {
|
||||
// if before callback exists
|
||||
ps[idx].after = p.name
|
||||
}
|
||||
}
|
||||
|
||||
if p.after != "" { // if defined after callback
|
||||
if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 {
|
||||
// if after callback sorted, append current callback to last
|
||||
sortedNames = append(sortedNames, p.name)
|
||||
} else if idx := getRIndex(allNames, p.after); idx != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
// set after callback's before callback to current callback
|
||||
if after := ps[idx]; after.before == "" {
|
||||
after.before = p.name
|
||||
sortProcessor(after)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sortedNames, p.name) == -1 {
|
||||
sortedNames = append(sortedNames, p.name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, p := range ps {
|
||||
sortProcessor(p)
|
||||
}
|
||||
|
||||
var fns []func(*DB)
|
||||
for _, name := range sortedNames {
|
||||
if idx := getRIndex(allNames, name); !ps[idx].remove {
|
||||
fns = append(fns, ps[idx].handler)
|
||||
}
|
||||
}
|
||||
|
||||
return fns
|
||||
}
|
||||
|
||||
// compile processors
|
||||
func (cs *Callbacks) compile(db *DB) {
|
||||
processors := map[string][]*processor{}
|
||||
for _, p := range cs.processors {
|
||||
if p.name != "" {
|
||||
if p.match == nil || p.match(db) {
|
||||
processors[p.kind] = append(processors[p.kind], p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, ps := range processors {
|
||||
switch name {
|
||||
case "create":
|
||||
cs.creates = sortProcessors(ps)
|
||||
case "query":
|
||||
cs.queries = sortProcessors(ps)
|
||||
case "update":
|
||||
cs.updates = sortProcessors(ps)
|
||||
case "delete":
|
||||
cs.deletes = sortProcessors(ps)
|
||||
case "row":
|
||||
cs.row = sortProcessors(ps)
|
||||
case "raw":
|
||||
cs.raw = sortProcessors(ps)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) {
|
||||
var got []string
|
||||
|
||||
for _, f := range funcs {
|
||||
got = append(got, getFuncName(f))
|
||||
}
|
||||
|
||||
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
|
||||
}
|
||||
|
||||
func getFuncName(fc func(*DB)) string {
|
||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".")
|
||||
return fnames[len(fnames)-1]
|
||||
}
|
||||
|
||||
func c1(*DB) {}
|
||||
func c2(*DB) {}
|
||||
func c3(*DB) {}
|
||||
func c4(*DB) {}
|
||||
func c5(*DB) {}
|
||||
|
||||
func TestCallbacks(t *testing.T) {
|
||||
type callback struct {
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
remove bool
|
||||
replace bool
|
||||
err error
|
||||
match func(*DB) bool
|
||||
h func(*DB)
|
||||
}
|
||||
|
||||
datas := []struct {
|
||||
callbacks []callback
|
||||
results []string
|
||||
}{
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
|
||||
results: []string{"c1", "c2", "c3", "c4", "c5"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
|
||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
|
||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
|
||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
|
||||
results: []string{"c1", "c5", "c2", "c3", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
|
||||
results: []string{"c1", "c3", "c5", "c2", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||
results: []string{"c1", "c5", "c3", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
results: []string{"c1", "c4", "c3"},
|
||||
},
|
||||
}
|
||||
|
||||
// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||
// var callback2 = &Callback{logger: defaultLogger}
|
||||
|
||||
// callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||
// callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||
// callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||
// callback2.Delete().Register("after_create1", afterCreate1)
|
||||
// callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||
|
||||
// if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||
// t.Errorf("register callback with order")
|
||||
// }
|
||||
// }
|
||||
|
||||
for idx, data := range datas {
|
||||
callbacks := &Callbacks{}
|
||||
|
||||
for _, c := range data.callbacks {
|
||||
p := callbacks.Create()
|
||||
|
||||
if c.name == "" {
|
||||
c.name = getFuncName(c.h)
|
||||
}
|
||||
|
||||
if c.before != "" {
|
||||
p = p.Before(c.before)
|
||||
}
|
||||
|
||||
if c.after != "" {
|
||||
p = p.After(c.after)
|
||||
}
|
||||
|
||||
if c.match != nil {
|
||||
p = p.Match(c.match)
|
||||
}
|
||||
|
||||
if c.remove {
|
||||
p.Remove(c.name)
|
||||
} else if c.replace {
|
||||
p.Replace(c.name, c.h)
|
||||
} else {
|
||||
p.Register(c.name, c.h)
|
||||
}
|
||||
}
|
||||
|
||||
if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok {
|
||||
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
package gorm
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRecordNotFound record not found error
|
||||
|
@ -13,10 +16,14 @@ var (
|
|||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e Error) Unwrap() error {
|
||||
return e.Err
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
|
@ -1,7 +1,15 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
|
||||
var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)}
|
||||
|
||||
const (
|
||||
Info LogLevel = iota + 1
|
||||
Warn
|
||||
|
@ -11,4 +19,42 @@ const (
|
|||
// Interface logger interface
|
||||
type Interface interface {
|
||||
LogMode(LogLevel) Interface
|
||||
Info(string, ...interface{})
|
||||
Warn(string, ...interface{})
|
||||
Error(string, ...interface{})
|
||||
}
|
||||
|
||||
// Writer log writer interface
|
||||
type Writer interface {
|
||||
Print(...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
Writer
|
||||
logLevel LogLevel
|
||||
}
|
||||
|
||||
func (logger Logger) LogMode(level LogLevel) Interface {
|
||||
return Logger{Writer: logger.Writer, logLevel: level}
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (logger Logger) Info(msg string, data ...interface{}) {
|
||||
if logger.logLevel >= Info {
|
||||
logger.Print("[info] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (logger Logger) Warn(msg string, data ...interface{}) {
|
||||
if logger.logLevel >= Warn {
|
||||
logger.Print("[warn] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (logger Logger) Error(msg string, data ...interface{}) {
|
||||
if logger.logLevel >= Error {
|
||||
logger.Print("[error] " + fmt.Sprintf(msg, data...))
|
||||
}
|
||||
}
|
||||
|
|
15
model.go
15
model.go
|
@ -1,15 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
|
||||
|
||||
func FileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
Loading…
Reference in New Issue