Add callback create, delete

This commit is contained in:
Jinzhu 2014-01-26 12:41:37 +08:00
parent 192ed06ad8
commit 973acd6339
9 changed files with 173 additions and 68 deletions

View File

@ -1,12 +1,14 @@
package gorm
import "fmt"
import (
"fmt"
)
type callback struct {
creates []*func()
updates []*func()
deletes []*func()
queries []*func()
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
processors []*callback_processor
}
@ -17,7 +19,7 @@ type callback_processor struct {
replace bool
remove bool
typ string
processor *func()
processor *func(scope *Scope)
callback *callback
}
@ -53,7 +55,7 @@ func (cp *callback_processor) After(name string) *callback_processor {
return cp
}
func (cp *callback_processor) Register(name string, fc func()) {
func (cp *callback_processor) Register(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.callback.sort()
@ -65,7 +67,7 @@ func (cp *callback_processor) Remove(name string) {
cp.callback.sort()
}
func (cp *callback_processor) Replace(name string, fc func()) {
func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.replace = true
@ -81,7 +83,7 @@ func getRIndex(strs []string, str string) int {
return -1
}
func sortProcessors(cps []*callback_processor) []*func() {
func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callback_processor, force bool)
var names, sortedNames = []string{}, []string{}
@ -137,8 +139,8 @@ func sortProcessors(cps []*callback_processor) []*func() {
sortCallbackProcessor(cp, false)
}
var funcs = []*func(){}
var sortedFuncs = []*func(){}
var funcs = []*func(scope *Scope){}
var sortedFuncs = []*func(scope *Scope){}
for _, name := range sortedNames {
index := getRIndex(names, name)
if !cps[index].remove {

41
callback_create.go Normal file
View File

@ -0,0 +1,41 @@
package gorm
func BeforeCreate(scope *Scope) {
scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeCreate")
}
func SaveBeforeAssociations(scope *Scope) {
}
func Create(scope *Scope) {
if !scope.HasError() {
var id interface{}
if scope.Dialect().SupportLastInsertId() {
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err = sql_result.LastInsertId()
scope.Err(err)
}
} else {
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
}
scope.SetColumn(scope.PrimaryKey(), id)
}
}
func AfterCreate(scope *Scope) {
scope.CallMethod("AfterCreate")
scope.CallMethod("AfterSave")
}
func SaveAfterAssociations(scope *Scope) {
}
func init() {
DefaultCallback.Create().Register("before_create", BeforeCreate)
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("create", Create)
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("after_create", AfterCreate)
}

33
callback_delete.go Normal file
View File

@ -0,0 +1,33 @@
package gorm
import (
"fmt"
"time"
)
func BeforeDelete(scope *Scope) {
scope.CallMethod("BeforeDelete")
}
func Delete(scope *Scope) {
if scope.HasError() {
return
}
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.TableName(), scope.AddToVars(time.Now()), scope.CombinedConditionSql()))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql()))
}
scope.Exec()
}
func AfterDelete(scope *Scope) {
scope.CallMethod("AfterDelete")
}
func init() {
DefaultCallback.Delete().Register("before_delete", BeforeDelete)
DefaultCallback.Delete().Register("delete", Delete)
DefaultCallback.Delete().Register("after_delete", AfterDelete)
}

View File

@ -7,7 +7,7 @@ import (
"testing"
)
func equalFuncs(funcs []*func(), fnames []string) bool {
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
@ -16,11 +16,11 @@ func equalFuncs(funcs []*func(), fnames []string) bool {
return reflect.DeepEqual(names, fnames)
}
func create() {}
func before_create1() {}
func before_create2() {}
func after_create1() {}
func after_create2() {}
func create(s *Scope) {}
func before_create1(s *Scope) {}
func before_create2(s *Scope) {}
func after_create1(s *Scope) {}
func after_create2(s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}
@ -76,7 +76,7 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) {
}
}
func replace_create() {}
func replace_create(s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}

View File

@ -1,12 +1,6 @@
package callback
package callbacks
import (
"fmt"
"github.com/jinzhu/gorm"
"time"
)
import "github.com/jinzhu/gorm"
func Create(scope *gorm.Scope) {
}
@ -15,32 +9,6 @@ func init() {
gorm.DefaultCallback.Create().Before().Register(Create)
}
func query(db *DB) {
}
func save(db *DB) {
}
func create(db *DB) {
}
func update(db *DB) {
}
func Delete(scope *Scope) {
scope.CallMethod("BeforeDelete")
if !scope.HasError() {
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.Table(), scope.AddToVars(time.Now()), scope.CombinedSql()))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.Table(), scope.CombinedSql()))
}
scope.Exec()
scope.CallMethod("AfterDelete")
}
}
func init() {
DefaultCallback.Create().Before("Delete").After("Lalala").Register("delete", Delete)
DefaultCallback.Update().Before("Delete").After("Lalala").Remove("replace", Delete)

View File

@ -8,7 +8,7 @@ import (
type DB struct {
Value interface{}
Callbacks *callback
callback *callback
Error error
db sqlCommon
parent *DB
@ -22,7 +22,7 @@ type DB struct {
func Open(driver, source string) (DB, error) {
var err error
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger}
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback}
db.db, err = sql.Open(driver, source)
db.parent = &db
return db, err

View File

@ -1,11 +1,7 @@
package gorm
import (
"fmt"
"os"
"regexp"
"runtime"
"strings"
"time"
)
@ -55,16 +51,6 @@ func (s *DB) hasError() bool {
return s.Error != nil
}
func fileWithLineNum() string {
for i := 1; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
}
}
return ""
}
func (s *DB) print(v ...interface{}) {
s.parent.logger.(logger).Print(v...)
}

61
scope.go Normal file
View File

@ -0,0 +1,61 @@
package gorm
import "github.com/jinzhu/gorm/dialect"
type Scope struct {
Search *search
Sql string
SqlVars []interface{}
db *DB
}
func (scope *Scope) DB() sqlCommon {
return scope.db.db
}
func (scope *Scope) Dialect() dialect.Dialect {
return scope.db.parent.dialect
}
func (scope *Scope) Err(err error) error {
if err != nil {
scope.db.err(err)
}
return err
}
func (scope *Scope) HasError() bool {
return true
}
func (scope *Scope) PrimaryKey() string {
return ""
}
func (scope *Scope) HasColumn(name string) bool {
return false
}
func (scope *Scope) SetColumn(column string, value interface{}) {
}
func (scope *Scope) CallMethod(name string) {
}
func (scope *Scope) CombinedConditionSql() string {
return ""
}
func (scope *Scope) AddToVars(value interface{}) string {
return ""
}
func (scope *Scope) TableName() string {
return ""
}
func (scope *Scope) Raw(sql string, values ...interface{}) {
}
func (scope *Scope) Exec() {
}

View File

@ -3,7 +3,11 @@ package gorm
import (
"bytes"
"database/sql"
"fmt"
"os"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
@ -86,6 +90,16 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
return
}
func fileWithLineNum() string {
for i := 1; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
}
}
return ""
}
func setFieldValue(field reflect.Value, value interface{}) bool {
if field.IsValid() && field.CanAddr() {
switch field.Kind() {