From 187eae8d9c209ed5db628afc76890288cb00a2cc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2015 18:30:59 +0800 Subject: [PATCH] Update with Select and Omit --- callback_create.go | 4 +-- callback_shared.go | 4 +-- callback_update.go | 11 ++++++-- create_test.go | 4 +-- scope.go | 2 +- update_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 84 insertions(+), 9 deletions(-) diff --git a/callback_create.go b/callback_create.go index ff111f76..b21df08b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -26,7 +26,7 @@ func Create(scope *Scope) { var sqls, columns []string fields := scope.Fields() for _, field := range fields { - if scope.ValidField(field) { + if scope.changeableField(field) { if field.IsNormal { if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { if !field.IsBlank || !field.HasDefaultValue { @@ -35,7 +35,7 @@ func Create(scope *Scope) { } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.ValidField(relationField) { + if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { columns = append(columns, scope.Quote(relationField.DBName)) sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) } diff --git a/callback_shared.go b/callback_shared.go index ce99d3f0..99ad8f50 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -12,7 +12,7 @@ func CommitOrRollbackTransaction(scope *Scope) { func SaveBeforeAssociations(scope *Scope) { for _, field := range scope.Fields() { - if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored { + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) @@ -26,7 +26,7 @@ func SaveBeforeAssociations(scope *Scope) { func SaveAfterAssociations(scope *Scope) { for _, field := range scope.Fields() { - if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored { + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field diff --git a/callback_update.go b/callback_update.go index b85e319a..e1725a98 100644 --- a/callback_update.go +++ b/callback_update.go @@ -46,11 +46,18 @@ func Update(scope *Scope) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } } else { - for _, field := range scope.Fields() { - if !field.IsPrimaryKey && field.IsNormal { + fields := scope.Fields() + for _, field := range fields { + if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { if !field.IsBlank || !field.HasDefaultValue { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } + } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { + if !relationField.IsBlank { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))) + } + } } } } diff --git a/create_test.go b/create_test.go index e3c09d05..97175980 100644 --- a/create_test.go +++ b/create_test.go @@ -124,7 +124,7 @@ func TestAnonymousField(t *testing.T) { func TestSelectWithCreate(t *testing.T) { user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user) + DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) var queryuser User DB.Preload("BillingAddress").Preload("ShippingAddress"). @@ -142,7 +142,7 @@ func TestSelectWithCreate(t *testing.T) { func TestOmitWithCreate(t *testing.T) { user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user) + DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) var queryuser User DB.Preload("BillingAddress").Preload("ShippingAddress"). diff --git a/scope.go b/scope.go index 7d8cb52e..9d235d05 100644 --- a/scope.go +++ b/scope.go @@ -351,7 +351,7 @@ func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } -func (scope *Scope) ValidField(field *Field) bool { +func (scope *Scope) changeableField(field *Field) bool { selectAttrs := scope.SelectAttrs() omitAttrs := scope.OmitAttrs() diff --git a/update_test.go b/update_test.go index 9b66fc01..e86e82c9 100644 --- a/update_test.go +++ b/update_test.go @@ -206,3 +206,71 @@ func TestUpdateColumn(t *testing.T) { t.Errorf("UpdateColumn with expression should not update UpdatedAt") } } + +func TestSelectWithUpdate(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update") + DB.Create(user) + + var reloadUser User + DB.First(&reloadUser, user.Id) + reloadUser.Name = "new_name" + reloadUser.Age = 50 + reloadUser.BillingAddress = Address{Address1: "New Billing Address"} + reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} + reloadUser.CreditCard = CreditCard{Number: "987654321"} + reloadUser.Emails = []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + } + reloadUser.Company = Company{Name: "new company"} + + DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || + queryUser.ShippingAddressId != user.ShippingAddressId || + queryUser.CreditCard.ID == user.CreditCard.ID || + len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { + t.Errorf("Should only update selected relationships") + } +} + +func TestOmitWithUpdate(t *testing.T) { + user := getPreparedUser("omit_user", "omit_with_update") + DB.Create(user) + + var reloadUser User + DB.First(&reloadUser, user.Id) + reloadUser.Name = "new_name" + reloadUser.Age = 50 + reloadUser.BillingAddress = Address{Address1: "New Billing Address"} + reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} + reloadUser.CreditCard = CreditCard{Number: "987654321"} + reloadUser.Emails = []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + } + reloadUser.Company = Company{Name: "new company"} + + DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || + queryUser.ShippingAddressId == user.ShippingAddressId || + queryUser.CreditCard.ID != user.CreditCard.ID || + len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { + t.Errorf("Should only update selected relationships") + } +}