make callback create works

This commit is contained in:
Jinzhu 2014-01-26 19:34:06 +08:00
parent 261ece5696
commit 8dd7b4ed91
7 changed files with 150 additions and 42 deletions

View File

@ -1,15 +1,40 @@
package gorm package gorm
import (
"fmt"
"strings"
"time"
)
func BeforeCreate(scope *Scope) { func BeforeCreate(scope *Scope) {
scope.CallMethod("BeforeSave") scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeCreate") scope.CallMethod("BeforeCreate")
} }
func SaveBeforeAssociations(scope *Scope) {
}
func Create(scope *Scope) { func Create(scope *Scope) {
defer scope.Trace(time.Now())
if !scope.HasError() { if !scope.HasError() {
// set create sql
var sqls, columns []string
for _, field := range scope.Fields() {
if field.IsBlank || len(field.SqlTag) == 0 {
continue
}
columns = append(columns, scope.quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value))
}
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v) %v",
scope.TableName(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
scope.Dialect().ReturningStr(scope.PrimaryKey()),
))
// execute create sql
var id interface{} var id interface{}
if scope.Dialect().SupportLastInsertId() { if scope.Dialect().SupportLastInsertId() {
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
@ -20,8 +45,10 @@ func Create(scope *Scope) {
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
} }
if !scope.HasError() {
scope.SetColumn(scope.PrimaryKey(), id) scope.SetColumn(scope.PrimaryKey(), id)
} }
}
} }
func AfterCreate(scope *Scope) { func AfterCreate(scope *Scope) {
@ -29,13 +56,12 @@ func AfterCreate(scope *Scope) {
scope.CallMethod("AfterSave") scope.CallMethod("AfterSave")
} }
func SaveAfterAssociations(scope *Scope) {
}
func init() { func init() {
DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("before_create", BeforeCreate) DefaultCallback.Create().Register("before_create", BeforeCreate)
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("create", Create) DefaultCallback.Create().Register("create", Create)
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("after_create", AfterCreate) DefaultCallback.Create().Register("after_create", AfterCreate)
DefaultCallback.Create().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction)
} }

View File

@ -12,10 +12,7 @@ func BeforeDelete(scope *Scope) {
func Delete(scope *Scope) { func Delete(scope *Scope) {
defer scope.Trace(time.Now()) defer scope.Trace(time.Now())
if scope.HasError() { if !scope.HasError() {
return
}
if !scope.Search.unscope && scope.HasColumn("DeletedAt") { if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
scope.Raw( scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
@ -28,6 +25,7 @@ func Delete(scope *Scope) {
} }
scope.Exec() scope.Exec()
}
} }
func AfterDelete(scope *Scope) { func AfterDelete(scope *Scope) {

View File

@ -7,3 +7,9 @@ func BeginTransaction(scope *Scope) {
func CommitOrRollbackTransaction(scope *Scope) { func CommitOrRollbackTransaction(scope *Scope) {
scope.CommitOrRollback() scope.CommitOrRollback()
} }
func SaveBeforeAssociations(scope *Scope) {
}
func SaveAfterAssociations(scope *Scope) {
}

15
do.go
View File

@ -311,21 +311,6 @@ func (s *Do) update() *Do {
return s return s
} }
func (s *Do) delete() *Do {
s.model.callMethod("BeforeDelete")
if !s.db.hasError() {
if !s.search.unscope && s.model.hasColumn("DeletedAt") {
s.setSql(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()))
} else {
s.setSql(fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()))
}
s.exec()
s.model.callMethod("AfterDelete")
}
return s
}
func (s *Do) prepareQuerySql() { func (s *Do) prepareQuerySql() {
if s.search.raw { if s.search.raw {
s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE ")) s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE "))

View File

@ -10,9 +10,17 @@ import (
type Field struct { type Field struct {
Name string Name string
DBName string
Value interface{} Value interface{}
model *Model IsBlank bool
IsIgnored bool
Tag string
AddationalTag string
Size int
SqlTag string
dbName string dbName string
model *Model
isBlank bool isBlank bool
ignoreField bool ignoreField bool
isPrimaryKey bool isPrimaryKey bool
@ -25,6 +33,16 @@ type Field struct {
structField reflect.StructField structField reflect.StructField
} }
func (f *Field) IsScanner() bool {
_, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
return is_scanner
}
func (f *Field) IsTime() bool {
_, is_time := f.Value.(time.Time)
return is_time
}
func (f *Field) parseBlank() { func (f *Field) parseBlank() {
f.isBlank = isBlank(f.reflectValue) f.isBlank = isBlank(f.reflectValue)
} }
@ -38,7 +56,7 @@ func (f *Field) parseIgnore() {
} }
func (f *Field) isScanner() bool { func (f *Field) isScanner() bool {
_, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner) _, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
return is_scanner return is_scanner
} }

View File

@ -178,7 +178,12 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
} }
func (s *DB) Save(value interface{}) *DB { func (s *DB) Save(value interface{}) *DB {
scope := s.clone().newScope(value)
if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db
} else {
return s.clone().do(value).begin().save().commit_or_rollback().db return s.clone().do(value).begin().save().commit_or_rollback().db
}
} }
func (s *DB) Delete(value interface{}) *DB { func (s *DB) Delete(value interface{}) *DB {

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/jinzhu/gorm/dialect" "github.com/jinzhu/gorm/dialect"
"go/ast"
"strings" "strings"
"time" "time"
@ -149,8 +150,77 @@ func (s *Scope) CombinedConditionSql() string {
return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql() return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql()
} }
func (scope *Scope) SqlTagForField(field *Field) (tag string) {
value := field.Value
reflect_value := reflect.ValueOf(value)
if field.IsScanner() {
value = reflect_value.Field(0).Interface()
}
switch reflect_value.Kind() {
case reflect.Slice:
if _, ok := value.([]byte); !ok {
return
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
return
}
}
if tag = field.Tag; len(tag) == 0 && tag != "-" {
if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
} else {
tag = scope.Dialect().SqlTag(value, field.Size)
}
if len(field.AddationalTag) > 0 {
tag = tag + " " + field.AddationalTag
}
}
return
}
func (scope *Scope) Fields() []*Field { func (scope *Scope) Fields() []*Field {
return []*Field{} indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value))
fields := []*Field{}
if !indirect_value.IsValid() {
return fields
}
typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ {
field_struct := typ.Field(i)
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) {
continue
}
var field Field
field.Name = field_struct.Name
field.DBName = toSnake(field_struct.Name)
value := indirect_value.FieldByName(field_struct.Name)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier))
field.Tag = tag
field.AddationalTag = addational_tag
field.Size = size
field.SqlTag = scope.SqlTagForField(&field)
if tag == "-" {
field.IsIgnored = true
}
field.parseAssociation()
fields = append(fields, &field)
}
return fields
} }
func (scope *Scope) Raw(sql string) { func (scope *Scope) Raw(sql string) {