package gorm import ( "errors" "fmt" "sort" "strings" ) // Define callbacks for updating func init() { DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) DefaultCallback.Update().Register("gorm:update", updateCallback) DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } // assignUpdatingAttributesCallback assign updating attributes to model func assignUpdatingAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { scope.InstanceSet("gorm:update_attrs", updateMaps) } else { scope.SkipLeft() } } } // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { scope.Err(errors.New("Missing WHERE clause while updating")) return } if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { scope.CallMethod("BeforeSave") } if !scope.HasError() { scope.CallMethod("BeforeUpdate") } } } // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", NowFunc()) } } // updateCallback the callback used to update data to database func updateCallback(scope *Scope) { if !scope.HasError() { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { // Sort the column names so that the generated SQL is the same every time. updateMap := updateAttrs.(map[string]interface{}) var columns []string for c := range updateMap { columns = append(columns, c) } sort.Strings(columns) for _, column := range columns { value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) } } } } } } var extraOption string if str, ok := scope.Get("gorm:update_option"); ok { extraOption = fmt.Sprint(str) } if len(sqls) > 0 { scope.Raw(fmt.Sprintf( "UPDATE %v SET %v%v%v", scope.QuotedTableName(), strings.Join(sqls, ", "), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() } } } // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating func afterUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { scope.CallMethod("AfterUpdate") } if !scope.HasError() { scope.CallMethod("AfterSave") } } }