From 58a7252251a205976d6fc0a99aeab9e6965fe9a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 17 Jan 2016 18:38:18 +0800 Subject: [PATCH] Refactor update callback --- callback_create.go | 4 ++-- callback_update.go | 46 +++++++++++++++++++++++++++------------------- scope_private.go | 31 +++---------------------------- utils_private.go | 2 +- 4 files changed, 33 insertions(+), 50 deletions(-) diff --git a/callback_create.go b/callback_create.go index 921aa4cc..b8725363 100644 --- a/callback_create.go +++ b/callback_create.go @@ -10,7 +10,7 @@ func init() { defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - defaultCallback.Create().Register("gorm:update_time_stamp_when_create", updateTimeStampForCreateCallback) + defaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) defaultCallback.Create().Register("gorm:create", createCallback) defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) @@ -120,7 +120,7 @@ func forceReloadAfterCreateCallback(scope *Scope) { } } -// beforeCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating +// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating func afterCreateCallback(scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterCreate") diff --git a/callback_update.go b/callback_update.go index b9d2bcbc..b3a6c7da 100644 --- a/callback_update.go +++ b/callback_update.go @@ -5,19 +5,21 @@ import ( "strings" ) +// Define callbacks for updating func init() { - defaultCallback.Update().Register("gorm:assign_update_attributes", assignUpdateAttributesCallback) + defaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - defaultCallback.Update().Register("gorm:update_time_stamp_when_update", updateTimeStampForUpdateCallback) + defaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) defaultCallback.Update().Register("gorm:update", updateCallback) defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } -func assignUpdateAttributesCallback(scope *Scope) { +// assignUpdatingAttributesCallback assign updating attributes to model +func assignUpdatingAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if maps := convertInterfaceToMap(attrs); len(maps) > 0 { protected, ok := scope.Get("gorm:ignore_protected_attrs") @@ -36,6 +38,7 @@ func assignUpdateAttributesCallback(scope *Scope) { } } +// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { @@ -47,32 +50,40 @@ func beforeUpdateCallback(scope *Scope) { } } +// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", NowFunc()) } } +// updateCallback the callback used to update data to database func updateCallback(scope *Scope) { if !scope.HasError() { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for key, value := range updateAttrs.(map[string]interface{}) { - if scope.changeableDBColumn(key) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) + for column, value := range updateAttrs.(map[string]interface{}) { + if field, ok := scope.FieldByName(column); ok { + if scope.changeableField(field) { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(value))) + } + } else { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } } else { fields := scope.Fields() for _, field := range fields { - if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, dbName := range relationship.ForeignDBNames { - if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { - sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) - sqls = append(sqls, sql) + if scope.changeableField(field) { + if !field.IsPrimaryKey && field.IsNormal { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + for _, foreignKey := range relationship.ForeignDBNames { + if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) { + sqls = append(sqls, + fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) + } } } } @@ -81,16 +92,13 @@ func updateCallback(scope *Scope) { if len(sqls) > 0 { scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v %v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - scope.CombinedConditionSql(), - )) - scope.Exec() + "UPDATE %v SET %v %v", scope.QuotedTableName(), strings.Join(sqls, ", "), scope.CombinedConditionSql(), + )).Exec() } } } +// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating func afterUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { diff --git a/scope_private.go b/scope_private.go index e0a98c1c..ef16cf93 100644 --- a/scope_private.go +++ b/scope_private.go @@ -441,32 +441,8 @@ func (scope *Scope) trace(t time.Time) { } } -func (scope *Scope) changeableDBColumn(column string) bool { - selectAttrs := scope.SelectAttrs() - omitAttrs := scope.OmitAttrs() - - if len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if column == ToDBName(attr) { - return true - } - } - return false - } - - for _, attr := range omitAttrs { - if column == ToDBName(attr) { - return false - } - } - return true -} - func (scope *Scope) changeableField(field *Field) bool { - selectAttrs := scope.SelectAttrs() - omitAttrs := scope.OmitAttrs() - - if len(selectAttrs) > 0 { + if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { for _, attr := range selectAttrs { if field.Name == attr || field.DBName == attr { return true @@ -475,7 +451,7 @@ func (scope *Scope) changeableField(field *Field) bool { return false } - for _, attr := range omitAttrs { + for _, attr := range scope.OmitAttrs() { if field.Name == attr || field.DBName == attr { return false } @@ -485,8 +461,7 @@ func (scope *Scope) changeableField(field *Field) bool { } func (scope *Scope) shouldSaveAssociations() bool { - saveAssociations, ok := scope.Get("gorm:save_associations") - if ok && !saveAssociations.(bool) { + if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { return false } return true && !scope.HasError() diff --git a/utils_private.go b/utils_private.go index 5c17eda5..2851a37e 100644 --- a/utils_private.go +++ b/utils_private.go @@ -46,7 +46,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { switch value := values.(type) { case map[string]interface{}: for k, v := range value { - attrs[ToDBName(k)] = v + attrs[k] = v } case []interface{}: for _, v := range value {