diff --git a/callback_update.go b/callback_update.go index e1725a98..1167871c 100644 --- a/callback_update.go +++ b/callback_update.go @@ -43,7 +43,9 @@ func Update(scope *Scope) { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for key, value := range updateAttrs.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) + if scope.changeableDBColumn(key) { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) + } } } else { fields := scope.Fields() diff --git a/scope.go b/scope.go index 9d235d05..fccc5b88 100644 --- a/scope.go +++ b/scope.go @@ -20,6 +20,7 @@ type Scope struct { primaryKeyField *Field skipLeft bool fields map[string]*Field + selectAttrs *[]string } func (scope *Scope) IndirectValue() reflect.Value { @@ -334,23 +335,48 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } -func (scope *Scope) SelectAttrs() (attrs []string) { - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) +func (scope *Scope) SelectAttrs() []string { + if scope.selectAttrs == nil { + attrs := []string{} + for _, value := range scope.Search.selects { + if str, ok := value.(string); ok { + attrs = append(attrs, str) + } else if strs, ok := value.([]interface{}); ok { + for _, str := range strs { + attrs = append(attrs, fmt.Sprintf("%v", str)) + } } } + scope.selectAttrs = &attrs } - return attrs + return *scope.selectAttrs } func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } +func (scope *Scope) changeableDBColumn(column string) bool { + selectAttrs := scope.SelectAttrs() + omitAttrs := scope.OmitAttrs() + + if len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if column == ToDBName(attr) { + return true + } + } + return false + } + + for _, attr := range omitAttrs { + if column == ToDBName(attr) { + return false + } + } + return true +} + func (scope *Scope) changeableField(field *Field) bool { selectAttrs := scope.SelectAttrs() omitAttrs := scope.OmitAttrs() diff --git a/update_test.go b/update_test.go index e86e82c9..8a019087 100644 --- a/update_test.go +++ b/update_test.go @@ -241,6 +241,42 @@ func TestSelectWithUpdate(t *testing.T) { } } +func TestSelectWithUpdateWithMap(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{ + "Name": "new_name", + "Age": 50, + "BillingAddress": Address{Address1: "New Billing Address"}, + "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, + "CreditCard": CreditCard{Number: "987654321"}, + "Emails": []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + }, + "Company": Company{Name: "new company"}, + } + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || + queryUser.ShippingAddressId != user.ShippingAddressId || + queryUser.CreditCard.ID == user.CreditCard.ID || + len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { + t.Errorf("Should only update selected relationships") + } +} + func TestOmitWithUpdate(t *testing.T) { user := getPreparedUser("omit_user", "omit_with_update") DB.Create(user) @@ -271,6 +307,78 @@ func TestOmitWithUpdate(t *testing.T) { queryUser.ShippingAddressId == user.ShippingAddressId || queryUser.CreditCard.ID != user.CreditCard.ID || len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update selected relationships") + t.Errorf("Should only update relationships that not omited") + } +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{ + "Name": "new_name", + "Age": 50, + "BillingAddress": Address{Address1: "New Billing Address"}, + "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, + "CreditCard": CreditCard{Number: "987654321"}, + "Emails": []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + }, + "Company": Company{Name: "new company"}, + } + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || + queryUser.ShippingAddressId == user.ShippingAddressId || + queryUser.CreditCard.ID != user.CreditCard.ID || + len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { + t.Errorf("Should only update relationships not omited") + } +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) + + var queryUser User + DB.First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) + + var queryUser User + DB.First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should omit name column when update user") } }