diff --git a/callback_create.go b/callback_create.go index 1bd8ff05..71db4ef0 100644 --- a/callback_create.go +++ b/callback_create.go @@ -32,6 +32,8 @@ func Create(scope *Scope) { if !field.IsBlank || !field.HasDefaultValue { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) + } else if field.HasDefaultValue { + scope.InstanceSet("gorm:force_reload_after_create", true) } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -95,6 +97,12 @@ func Create(scope *Scope) { } } +func ForceReloadAfterCreate(scope *Scope) { + if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { + scope.DB().New().First(scope.Value) + } +} + func AfterCreate(scope *Scope) { scope.CallMethodWithErrorCheck("AfterCreate") scope.CallMethodWithErrorCheck("AfterSave") @@ -106,6 +114,7 @@ func init() { DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) DefaultCallback.Create().Register("gorm:create", Create) + DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("gorm:after_create", AfterCreate) DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callback_update.go b/callback_update.go index 6090ee6b..4c9952d2 100644 --- a/callback_update.go +++ b/callback_update.go @@ -51,9 +51,7 @@ func Update(scope *Scope) { fields := scope.Fields() for _, field := range fields { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - if !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, dbName := range relationship.ForeignDBNames { if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { diff --git a/update_test.go b/update_test.go index 9a0af806..75877488 100644 --- a/update_test.go +++ b/update_test.go @@ -113,6 +113,14 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { if animal.Name != "amazing horse" { t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) } + + // When changing a field with a default value with blank value + animal.Name = "" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "" { + t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) + } } func TestUpdates(t *testing.T) {