make callback deletes works

This commit is contained in:
Jinzhu 2014-01-27 22:36:08 +08:00
parent eab146a275
commit 506d14a2f2
6 changed files with 171 additions and 54 deletions

View File

@ -11,7 +11,7 @@ func BeforeCreate(scope *Scope) {
scope.CallMethod("BeforeCreate") scope.CallMethod("BeforeCreate")
} }
func UpdateCreateTimeStamp(scope *Scope) { func UpdateTimeStampWhenCreate(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.SetColumn("CreatedAt", time.Now()) scope.SetColumn("CreatedAt", time.Now())
scope.SetColumn("UpdatedAt", time.Now()) scope.SetColumn("UpdatedAt", time.Now())
@ -66,7 +66,7 @@ func init() {
DefaultCallback.Create().Register("begin_transaction", BeginTransaction) DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("before_create", BeforeCreate) DefaultCallback.Create().Register("before_create", BeforeCreate)
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("update_create_time_stamp", UpdateCreateTimeStamp) DefaultCallback.Create().Register("update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("create", Create) DefaultCallback.Create().Register("create", Create)
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("after_create", AfterCreate) DefaultCallback.Create().Register("after_create", AfterCreate)

View File

@ -6,6 +6,21 @@ import (
"time" "time"
) )
func AssignUpdateAttributes(scope *Scope) {
if attrs, ok := scope.Get("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs")
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
if len(updateAttrs) > 0 {
scope.Set("gorm:update_attrs", updateAttrs)
} else if !hasUpdate {
scope.SkipLeft()
return
}
}
}
}
func BeforeUpdate(scope *Scope) { func BeforeUpdate(scope *Scope) {
scope.CallMethod("BeforeSave") scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeUpdate") scope.CallMethod("BeforeUpdate")
@ -18,13 +33,23 @@ func UpdateTimeStampWhenUpdate(scope *Scope) {
} }
func Update(scope *Scope) { func Update(scope *Scope) {
defer scope.Trace(time.Now())
if !scope.HasError() { if !scope.HasError() {
var sqls []string var sqls []string
updateAttrs, ok := scope.Get("gorm:update_attrs")
if ok {
for key, value := range updateAttrs.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(key), scope.AddToVars(value)))
}
} else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value)))
} }
} }
}
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v", "UPDATE %v SET %v %v",
@ -42,6 +67,7 @@ func AfterUpdate(scope *Scope) {
} }
func init() { func init() {
DefaultCallback.Update().Register("assign_update_attributes", AssignUpdateAttributes)
DefaultCallback.Update().Register("begin_transaction", BeginTransaction) DefaultCallback.Update().Register("begin_transaction", BeginTransaction)
DefaultCallback.Update().Register("before_update", BeforeUpdate) DefaultCallback.Update().Register("before_update", BeforeUpdate)
DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations) DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations)

View File

@ -929,7 +929,8 @@ func TestUpdate(t *testing.T) {
func TestUpdates(t *testing.T) { func TestUpdates(t *testing.T) {
product1 := Product{Code: "abc", Price: 10} product1 := Product{Code: "abc", Price: 10}
product2 := Product{Code: "cde", Price: 20} product2 := Product{Code: "cde", Price: 20}
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) db.Save(&product1).Save(&product2)
db.Model(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
if product2.Code != "edf" || product2.Price != 100 { if product2.Code != "edf" || product2.Price != 100 {
t.Errorf("Record should be updated also with update attributes") t.Errorf("Record should be updated also with update attributes")
} }

View File

@ -165,8 +165,10 @@ func (s *DB) Update(attrs ...interface{}) *DB {
} }
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().do(s.Value).begin().updateAttrs(values, ignoreProtectedAttrs...).update().commit_or_rollback().db return s.clone().NewScope(s.Value).
// return s.clone().NewScope(s.Value).callCallbacks(s.parent.callback.updates).db Set("gorm:update_interface", values).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
callCallbacks(s.parent.callback.updates).db
} }
func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumn(attrs ...interface{}) *DB {

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/jinzhu/gorm/dialect" "github.com/jinzhu/gorm/dialect"
"go/ast" "go/ast"
"strconv"
"strings" "strings"
"time" "time"
@ -19,6 +20,7 @@ type Scope struct {
SqlVars []interface{} SqlVars []interface{}
db *DB db *DB
_values map[string]interface{} _values map[string]interface{}
skipLeft bool
} }
func (db *DB) NewScope(value interface{}) *Scope { func (db *DB) NewScope(value interface{}) *Scope {
@ -26,9 +28,16 @@ func (db *DB) NewScope(value interface{}) *Scope {
return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}}
} }
func (scope *Scope) SkipLeft() {
scope.skipLeft = true
}
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
for _, f := range funcs { for _, f := range funcs {
(*f)(scope) (*f)(scope)
if scope.skipLeft {
break
}
} }
return scope return scope
} }
@ -90,12 +99,54 @@ func (scope *Scope) HasColumn(name string) bool {
return false return false
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
data := reflect.Indirect(reflect.ValueOf(scope.Value))
if !data.CanAddr() {
return values, true
}
for key, value := range values {
if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() {
if field.Interface() != value {
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if s, ok := value.(string); ok {
i, err := strconv.Atoi(s)
if scope.Err(err) == nil {
value = i
}
}
scope.db.log(field.Int() != reflect.ValueOf(value).Int())
if field.Int() != reflect.ValueOf(value).Int() {
hasUpdate = true
setFieldValue(field, value)
}
default:
hasUpdate = true
setFieldValue(field, value)
}
}
}
}
return
}
func (scope *Scope) SetColumn(column string, value interface{}) { func (scope *Scope) SetColumn(column string, value interface{}) {
if scope.Value == nil {
return
}
data := reflect.Indirect(reflect.ValueOf(scope.Value)) data := reflect.Indirect(reflect.ValueOf(scope.Value))
setFieldValue(data.FieldByName(snakeToUpperCamel(column)), value) setFieldValue(data.FieldByName(snakeToUpperCamel(column)), value)
} }
func (scope *Scope) CallMethod(name string) { func (scope *Scope) CallMethod(name string) {
if scope.Value == nil {
return
}
if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() { if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() {
fi := fm.Interface() fi := fm.Interface()
if f, ok := fi.(func()); ok { if f, ok := fi.(func()); ok {
@ -193,31 +244,32 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) {
} }
func (scope *Scope) Fields() []*Field { func (scope *Scope) Fields() []*Field {
indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value)) indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
fields := []*Field{} fields := []*Field{}
if !indirect_value.IsValid() { if !indirectValue.IsValid() {
return fields return fields
} }
scope_typ := indirect_value.Type() scopeTyp := indirectValue.Type()
for i := 0; i < scope_typ.NumField(); i++ { for i := 0; i < scopeTyp.NumField(); i++ {
field_struct := scope_typ.Field(i) fieldStruct := scopeTyp.Field(i)
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) { if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) {
continue continue
} }
var field Field var field Field
field.Name = field_struct.Name field.Name = fieldStruct.Name
field.DBName = toSnake(field_struct.Name) field.DBName = toSnake(fieldStruct.Name)
value := indirect_value.FieldByName(field_struct.Name) value := indirectValue.FieldByName(fieldStruct.Name)
field.Value = value.Interface() field.Value = value.Interface()
field.IsBlank = isBlank(value) field.IsBlank = isBlank(value)
tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier)) if scope.db != nil {
tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier))
field.Tag = tag field.Tag = tag
field.AddationalTag = addational_tag field.AddationalTag = addationalTag
field.Size = size field.Size = size
field.SqlTag = scope.SqlTagForField(&field) field.SqlTag = scope.SqlTagForField(&field)
@ -234,7 +286,7 @@ func (scope *Scope) Fields() []*Field {
typ = typ.Elem() typ = typ.Elem()
if _, ok := field.Value.([]byte); !ok { if _, ok := field.Value.([]byte); !ok {
foreignKey := scope_typ.Name() + "Id" foreignKey := scopeTyp.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey field.ForeignKey = foreignKey
} }
@ -246,7 +298,7 @@ func (scope *Scope) Fields() []*Field {
field.ForeignKey = field.Name + "Id" field.ForeignKey = field.Name + "Id"
field.BeforeAssociation = true field.BeforeAssociation = true
} else { } else {
foreignKey := scope_typ.Name() + "Id" foreignKey := scopeTyp.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey field.ForeignKey = foreignKey
} }
@ -254,6 +306,7 @@ func (scope *Scope) Fields() []*Field {
} }
} }
} }
}
fields = append(fields, &field) fields = append(fields, &field)
} }
@ -276,8 +329,9 @@ func (scope *Scope) Get(name string) (value interface{}, ok bool) {
return return
} }
func (scope *Scope) Set(name string, value interface{}) { func (scope *Scope) Set(name string, value interface{}) *Scope {
scope._values[name] = value scope._values[name] = value
return scope
} }
func (scope *Scope) Trace(t time.Time) { func (scope *Scope) Trace(t time.Time) {

View File

@ -124,3 +124,37 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
func isBlank(value reflect.Value) bool { func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
} }
func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
for k, v := range value {
attrs[toSnake(k)] = v
}
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[toSnake(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
scope := Scope{Value: values}
for _, field := range scope.Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Value
}
}
}
}
return attrs
}