forked from mirror/gorm
Add callback create, delete
This commit is contained in:
parent
192ed06ad8
commit
973acd6339
24
callback.go
24
callback.go
|
@ -1,12 +1,14 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
type callback struct {
|
type callback struct {
|
||||||
creates []*func()
|
creates []*func(scope *Scope)
|
||||||
updates []*func()
|
updates []*func(scope *Scope)
|
||||||
deletes []*func()
|
deletes []*func(scope *Scope)
|
||||||
queries []*func()
|
queries []*func(scope *Scope)
|
||||||
processors []*callback_processor
|
processors []*callback_processor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,7 +19,7 @@ type callback_processor struct {
|
||||||
replace bool
|
replace bool
|
||||||
remove bool
|
remove bool
|
||||||
typ string
|
typ string
|
||||||
processor *func()
|
processor *func(scope *Scope)
|
||||||
callback *callback
|
callback *callback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,7 +55,7 @@ func (cp *callback_processor) After(name string) *callback_processor {
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callback_processor) Register(name string, fc func()) {
|
func (cp *callback_processor) Register(name string, fc func(scope *Scope)) {
|
||||||
cp.name = name
|
cp.name = name
|
||||||
cp.processor = &fc
|
cp.processor = &fc
|
||||||
cp.callback.sort()
|
cp.callback.sort()
|
||||||
|
@ -65,7 +67,7 @@ func (cp *callback_processor) Remove(name string) {
|
||||||
cp.callback.sort()
|
cp.callback.sort()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callback_processor) Replace(name string, fc func()) {
|
func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) {
|
||||||
cp.name = name
|
cp.name = name
|
||||||
cp.processor = &fc
|
cp.processor = &fc
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
|
@ -81,7 +83,7 @@ func getRIndex(strs []string, str string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortProcessors(cps []*callback_processor) []*func() {
|
func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
|
||||||
var sortCallbackProcessor func(c *callback_processor, force bool)
|
var sortCallbackProcessor func(c *callback_processor, force bool)
|
||||||
var names, sortedNames = []string{}, []string{}
|
var names, sortedNames = []string{}, []string{}
|
||||||
|
|
||||||
|
@ -137,8 +139,8 @@ func sortProcessors(cps []*callback_processor) []*func() {
|
||||||
sortCallbackProcessor(cp, false)
|
sortCallbackProcessor(cp, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
var funcs = []*func(){}
|
var funcs = []*func(scope *Scope){}
|
||||||
var sortedFuncs = []*func(){}
|
var sortedFuncs = []*func(scope *Scope){}
|
||||||
for _, name := range sortedNames {
|
for _, name := range sortedNames {
|
||||||
index := getRIndex(names, name)
|
index := getRIndex(names, name)
|
||||||
if !cps[index].remove {
|
if !cps[index].remove {
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
func BeforeCreate(scope *Scope) {
|
||||||
|
scope.CallMethod("BeforeSave")
|
||||||
|
scope.CallMethod("BeforeCreate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveBeforeAssociations(scope *Scope) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func Create(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
var id interface{}
|
||||||
|
if scope.Dialect().SupportLastInsertId() {
|
||||||
|
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
|
id, err = sql_result.LastInsertId()
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.SetColumn(scope.PrimaryKey(), id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterCreate(scope *Scope) {
|
||||||
|
scope.CallMethod("AfterCreate")
|
||||||
|
scope.CallMethod("AfterSave")
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveAfterAssociations(scope *Scope) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Create().Register("before_create", BeforeCreate)
|
||||||
|
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
|
||||||
|
DefaultCallback.Create().Register("create", Create)
|
||||||
|
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
|
||||||
|
DefaultCallback.Create().Register("after_create", AfterCreate)
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BeforeDelete(scope *Scope) {
|
||||||
|
scope.CallMethod("BeforeDelete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Delete(scope *Scope) {
|
||||||
|
if scope.HasError() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
|
||||||
|
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.TableName(), scope.AddToVars(time.Now()), scope.CombinedConditionSql()))
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql()))
|
||||||
|
}
|
||||||
|
scope.Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterDelete(scope *Scope) {
|
||||||
|
scope.CallMethod("AfterDelete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Delete().Register("before_delete", BeforeDelete)
|
||||||
|
DefaultCallback.Delete().Register("delete", Delete)
|
||||||
|
DefaultCallback.Delete().Register("after_delete", AfterDelete)
|
||||||
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func equalFuncs(funcs []*func(), fnames []string) bool {
|
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
||||||
var names []string
|
var names []string
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
||||||
|
@ -16,11 +16,11 @@ func equalFuncs(funcs []*func(), fnames []string) bool {
|
||||||
return reflect.DeepEqual(names, fnames)
|
return reflect.DeepEqual(names, fnames)
|
||||||
}
|
}
|
||||||
|
|
||||||
func create() {}
|
func create(s *Scope) {}
|
||||||
func before_create1() {}
|
func before_create1(s *Scope) {}
|
||||||
func before_create2() {}
|
func before_create2(s *Scope) {}
|
||||||
func after_create1() {}
|
func after_create1(s *Scope) {}
|
||||||
func after_create2() {}
|
func after_create2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callback_processor{}}
|
var callback = &callback{processors: []*callback_processor{}}
|
||||||
|
@ -76,7 +76,7 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func replace_create() {}
|
func replace_create(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callback_processor{}}
|
var callback = &callback{processors: []*callback_processor{}}
|
||||||
|
|
|
@ -1,12 +1,6 @@
|
||||||
package callback
|
package callbacks
|
||||||
|
|
||||||
import (
|
import "github.com/jinzhu/gorm"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Create(scope *gorm.Scope) {
|
func Create(scope *gorm.Scope) {
|
||||||
}
|
}
|
||||||
|
@ -15,32 +9,6 @@ func init() {
|
||||||
gorm.DefaultCallback.Create().Before().Register(Create)
|
gorm.DefaultCallback.Create().Before().Register(Create)
|
||||||
}
|
}
|
||||||
|
|
||||||
func query(db *DB) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func save(db *DB) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func create(db *DB) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func update(db *DB) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func Delete(scope *Scope) {
|
|
||||||
scope.CallMethod("BeforeDelete")
|
|
||||||
|
|
||||||
if !scope.HasError() {
|
|
||||||
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
|
|
||||||
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.Table(), scope.AddToVars(time.Now()), scope.CombinedSql()))
|
|
||||||
} else {
|
|
||||||
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.Table(), scope.CombinedSql()))
|
|
||||||
}
|
|
||||||
scope.Exec()
|
|
||||||
scope.CallMethod("AfterDelete")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Create().Before("Delete").After("Lalala").Register("delete", Delete)
|
DefaultCallback.Create().Before("Delete").After("Lalala").Register("delete", Delete)
|
||||||
DefaultCallback.Update().Before("Delete").After("Lalala").Remove("replace", Delete)
|
DefaultCallback.Update().Before("Delete").After("Lalala").Remove("replace", Delete)
|
||||||
|
|
4
main.go
4
main.go
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Callbacks *callback
|
callback *callback
|
||||||
Error error
|
Error error
|
||||||
db sqlCommon
|
db sqlCommon
|
||||||
parent *DB
|
parent *DB
|
||||||
|
@ -22,7 +22,7 @@ type DB struct {
|
||||||
|
|
||||||
func Open(driver, source string) (DB, error) {
|
func Open(driver, source string) (DB, error) {
|
||||||
var err error
|
var err error
|
||||||
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger}
|
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback}
|
||||||
db.db, err = sql.Open(driver, source)
|
db.db, err = sql.Open(driver, source)
|
||||||
db.parent = &db
|
db.parent = &db
|
||||||
return db, err
|
return db, err
|
||||||
|
|
14
private.go
14
private.go
|
@ -1,11 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,16 +51,6 @@ func (s *DB) hasError() bool {
|
||||||
return s.Error != nil
|
return s.Error != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileWithLineNum() string {
|
|
||||||
for i := 1; i < 15; i++ {
|
|
||||||
_, file, line, ok := runtime.Caller(i)
|
|
||||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
|
||||||
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) print(v ...interface{}) {
|
func (s *DB) print(v ...interface{}) {
|
||||||
s.parent.logger.(logger).Print(v...)
|
s.parent.logger.(logger).Print(v...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm/dialect"
|
||||||
|
|
||||||
|
type Scope struct {
|
||||||
|
Search *search
|
||||||
|
Sql string
|
||||||
|
SqlVars []interface{}
|
||||||
|
db *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) DB() sqlCommon {
|
||||||
|
return scope.db.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) Dialect() dialect.Dialect {
|
||||||
|
return scope.db.parent.dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) Err(err error) error {
|
||||||
|
if err != nil {
|
||||||
|
scope.db.err(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) HasError() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) PrimaryKey() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) HasColumn(name string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) SetColumn(column string, value interface{}) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) CallMethod(name string) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) CombinedConditionSql() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) TableName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) Raw(sql string, values ...interface{}) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) Exec() {
|
||||||
|
}
|
14
utils.go
14
utils.go
|
@ -3,7 +3,11 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -86,6 +90,16 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fileWithLineNum() string {
|
||||||
|
for i := 1; i < 15; i++ {
|
||||||
|
_, file, line, ok := runtime.Caller(i)
|
||||||
|
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||||
|
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func setFieldValue(field reflect.Value, value interface{}) bool {
|
func setFieldValue(field reflect.Value, value interface{}) bool {
|
||||||
if field.IsValid() && field.CanAddr() {
|
if field.IsValid() && field.CanAddr() {
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
|
|
Loading…
Reference in New Issue