From 174cfc5cbda81b3f51f847c9553c99eb3185ad07 Mon Sep 17 00:00:00 2001 From: black Date: Wed, 28 Jun 2023 13:31:13 +0800 Subject: [PATCH] fix: on confilct with default value --- callbacks/create.go | 6 ++-- tests/associations_test.go | 51 +++++++++++++++++++++++++++++++ tests/create_test.go | 61 +++++++++++++++++++++++++++++--------- tests/go.mod | 1 + 4 files changed, 102 insertions(+), 17 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index f0b78139..766eacf0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -3,7 +3,6 @@ package callbacks import ( "fmt" "reflect" - "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -303,8 +302,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || - strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { + // We can confirm the column either has a value or default value, + // so checking HasDefaultValue again will cause some columns to be missed. + if !field.PrimaryKey && field.AutoCreateTime == 0 { if field.AutoUpdateTime > 0 { assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} switch field.AutoUpdateTime { diff --git a/tests/associations_test.go b/tests/associations_test.go index 4e8862e5..c4e35020 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" @@ -183,6 +184,56 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { } } +func TestFullSaveAssociationsWithJSONDefault(t *testing.T) { + if DB.Dialector.Name() == "mysql" { + t.Skip() // mysql json can't have a default value + } + + type ValueDep struct { + ID int + ValueID int + Name string + Params datatypes.JSONMap `gorm:"default:'{}'"` + } + type Value struct { + ID int + Name string + Dep ValueDep + } + + if err := DB.Migrator().DropTable(&ValueDep{}, &Value{}); err != nil { + t.Fatalf("failed to drop value table, got err: %v", err) + } + if err := DB.AutoMigrate(&ValueDep{}, &Value{}); err != nil { + t.Fatalf("failed to migrate value table, got err: %v", err) + } + + if err := DB.Create(&Value{ + Name: "foo", + Dep: ValueDep{Name: "bar", Params: map[string]interface{}{"foo": "bar"}}, + }).Error; err != nil { + t.Errorf("failed to create value, got err: %v", err) + } + + var value Value + if err := DB.Preload("Dep").First(&value).Error; err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + + value.Dep.Params["foo"] = "new bar" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&value).Error; err != nil { + t.Errorf("failed to svae value, got err: %v", err) + } + + var result Value + if err := DB.Preload("Dep").First(&result).Error; err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + if result.Dep.Params["foo"] != "new bar" { + t.Errorf("failed to save value dep params, got: %v", result.Dep.Params) + } +} + func TestFullSaveAssociations(t *testing.T) { coupon := &Coupon{ AppliesToProduct: []*CouponProduct{ diff --git a/tests/create_test.go b/tests/create_test.go index 02613b72..7cd397f9 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -580,38 +581,70 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { } } -func TestCreateOnConfilctWithDefalutNull(t *testing.T) { - type OnConfilctUser struct { +func TestCreateOnConflictWithDefaultNull(t *testing.T) { + type OnConflictUser struct { ID string Name string `gorm:"default:null"` Email string Mobile string `gorm:"default:'133xxxx'"` } - err := DB.Migrator().DropTable(&OnConfilctUser{}) + err := DB.Migrator().DropTable(&OnConflictUser{}) AssertEqual(t, err, nil) - err = DB.AutoMigrate(&OnConfilctUser{}) + err = DB.AutoMigrate(&OnConflictUser{}) AssertEqual(t, err, nil) - u := OnConfilctUser{ - ID: "on-confilct-user-id", - Name: "on-confilct-user-name", - Email: "on-confilct-user-email", - Mobile: "on-confilct-user-mobile", + u := OnConflictUser{ + ID: "on-conflict-user-id", + Name: "on-conflict-user-name", + Email: "on-conflict-user-email", + Mobile: "on-conflict-user-mobile", } err = DB.Create(&u).Error AssertEqual(t, err, nil) - u.Name = "on-confilct-user-name-2" - u.Email = "on-confilct-user-email-2" + u.Name = "on-conflict-user-name-2" + u.Email = "on-conflict-user-email-2" u.Mobile = "" err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error AssertEqual(t, err, nil) - var u2 OnConfilctUser + var u2 OnConflictUser err = DB.Where("id = ?", u.ID).First(&u2).Error AssertEqual(t, err, nil) - AssertEqual(t, u2.Name, "on-confilct-user-name-2") - AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Name, "on-conflict-user-name-2") + AssertEqual(t, u2.Email, "on-conflict-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } + +func TestCreateOnConflictWithDefaultJSON(t *testing.T) { + if DB.Dialector.Name() == "mysql" { + t.Skip() // mysql json can't have a default value + } + type OnConflictValue struct { + ID int + Params datatypes.JSONMap `gorm:"default:'{}'"` + } + + err := DB.Migrator().DropTable(&OnConflictValue{}) + AssertEqual(t, err, nil) + err = DB.AutoMigrate(&OnConflictValue{}) + AssertEqual(t, err, nil) + + v := OnConflictValue{ + Params: datatypes.JSONMap{"foo": "bar"}, + } + err = DB.Create(&v).Error + AssertEqual(t, err, nil) + + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&OnConflictValue{ + ID: v.ID, + Params: datatypes.JSONMap{"foo": "new-bar"}, + }).Error + AssertEqual(t, err, nil) + + var v2 OnConflictValue + err = DB.Where("id = ?", v.ID).First(&v2).Error + AssertEqual(t, err, nil) + AssertEqual(t, v2.Params, datatypes.JSONMap{"foo": "new-bar"}) +} diff --git a/tests/go.mod b/tests/go.mod index 0b38b9d0..d070d808 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,6 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.16 // indirect + gorm.io/datatypes v1.2.0 gorm.io/driver/mysql v1.5.0 gorm.io/driver/postgres v1.5.0 gorm.io/driver/sqlite v1.5.0