diff --git a/callback_create.go b/callback_create.go index 6eddddb3..ff111f76 100644 --- a/callback_create.go +++ b/callback_create.go @@ -24,11 +24,21 @@ func Create(scope *Scope) { if !scope.HasError() { // set create sql var sqls, columns []string - for _, field := range scope.Fields() { - if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) { - if !field.IsBlank || !field.HasDefaultValue { - columns = append(columns, scope.Quote(field.DBName)) - sqls = append(sqls, scope.AddToVars(field.Field.Interface())) + fields := scope.Fields() + for _, field := range fields { + if scope.ValidField(field) { + if field.IsNormal { + if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { + if !field.IsBlank || !field.HasDefaultValue { + columns = append(columns, scope.Quote(field.DBName)) + sqls = append(sqls, scope.AddToVars(field.Field.Interface())) + } + } + } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + if relationField := fields[relationship.ForeignDBName]; !scope.ValidField(relationField) { + columns = append(columns, scope.Quote(relationField.DBName)) + sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } } } } diff --git a/callback_shared.go b/callback_shared.go index 48f0b937..ce99d3f0 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -12,7 +12,7 @@ func CommitOrRollbackTransaction(scope *Scope) { func SaveBeforeAssociations(scope *Scope) { for _, field := range scope.Fields() { - if !field.IsBlank && !field.IsIgnored { + if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) @@ -26,7 +26,7 @@ func SaveBeforeAssociations(scope *Scope) { func SaveAfterAssociations(scope *Scope) { for _, field := range scope.Fields() { - if !field.IsBlank && !field.IsIgnored { + if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field diff --git a/create_test.go b/create_test.go index a5b46630..f8f4d642 100644 --- a/create_test.go +++ b/create_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "fmt" "reflect" "testing" "time" @@ -121,3 +122,22 @@ func TestAnonymousField(t *testing.T) { t.Errorf("Should be able to get anonymous field") } } + +func TestSelectCreate(t *testing.T) { + user := getPreparedUser("user1", "select_create") + DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user) + + var user2 User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&user2, user.Id) + + if user2.Name != user.Name || user2.Age == user.Age { + t.Errorf("Should only create users with name column") + } + + fmt.Println(user2.CreditCard.ID) + if user2.BillingAddressID.Int64 == 0 || user2.ShippingAddressId != 0 || + user2.CreditCard.ID == 0 || len(user2.Emails) == 0 { + t.Errorf("Should only create users with name column") + } +} diff --git a/main_test.go b/main_test.go index 6ffbbf8d..3864fcbd 100644 --- a/main_test.go +++ b/main_test.go @@ -89,8 +89,12 @@ func TestExceptionsWithInvalidSql(t *testing.T) { } func TestSetTable(t *testing.T) { - if DB.Table("users").Pluck("age", &[]int{}).Error != nil { - t.Errorf("No errors should happen if set table for pluck") + DB.Create(getPreparedUser("pluck_user1", "pluck_user")) + DB.Create(getPreparedUser("pluck_user2", "pluck_user")) + DB.Create(getPreparedUser("pluck_user3", "pluck_user")) + + if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { + t.Errorf("No errors should happen if set table for pluck", err.Error()) } var users []User @@ -115,9 +119,11 @@ func TestSetTable(t *testing.T) { t.Errorf("Query from specified table") } + DB.Save(getPreparedUser("normal_user", "reset_table")) + DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) var user1, user2, user3 User - DB.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name == user2.Name) || (user1.Name != user3.Name) { + DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) + if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { t.Errorf("unset specified table with blank string") } } diff --git a/preload_test.go b/preload_test.go index 6bb498c0..2547933b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,29 +1,9 @@ package gorm_test -import ( - "fmt" - "testing" -) +import "testing" -func getPreloadUser(name string) User { - var company Company - DB.Where(Company{Name: "preload"}).FirstOrCreate(&company) - - return User{ - Name: name, - Role: Role{"Preload"}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } +func getPreloadUser(name string) *User { + return getPreparedUser(name, "Preload") } func checkUserHasPreloadData(user User, t *testing.T) { @@ -64,7 +44,7 @@ func checkUserHasPreloadData(user User, t *testing.T) { func TestPreload(t *testing.T) { user1 := getPreloadUser("user1") - DB.Save(&user1) + DB.Save(user1) preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). Preload("CreditCard").Preload("Emails").Preload("Company") @@ -73,10 +53,10 @@ func TestPreload(t *testing.T) { checkUserHasPreloadData(user, t) user2 := getPreloadUser("user2") - DB.Save(&user2) + DB.Save(user2) user3 := getPreloadUser("user3") - DB.Save(&user3) + DB.Save(user3) var users []User preloadDB.Find(&users) diff --git a/scope.go b/scope.go index b8cbd1f3..7d8cb52e 100644 --- a/scope.go +++ b/scope.go @@ -333,3 +333,42 @@ func (scope *Scope) CommitOrRollback() *Scope { } return scope } + +func (scope *Scope) SelectAttrs() (attrs []string) { + for _, value := range scope.Search.selects { + if str, ok := value.(string); ok { + attrs = append(attrs, str) + } else if strs, ok := value.([]interface{}); ok { + for _, str := range strs { + attrs = append(attrs, fmt.Sprintf("%v", str)) + } + } + } + return attrs +} + +func (scope *Scope) OmitAttrs() []string { + return scope.Search.omits +} + +func (scope *Scope) ValidField(field *Field) bool { + selectAttrs := scope.SelectAttrs() + omitAttrs := scope.OmitAttrs() + + if len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if field.Name == attr || field.DBName == attr { + return true + } + } + return false + } + + for _, attr := range omitAttrs { + if field.Name == attr || field.DBName == attr { + return false + } + } + + return !field.IsIgnored +} diff --git a/search.go b/search.go index 47f5a6cc..502c226f 100644 --- a/search.go +++ b/search.go @@ -89,18 +89,6 @@ func (s *search) Omit(columns ...string) *search { return s } -func (s *search) SelectAttrs() (attrs []string) { - for key, value := range s.selects { - attrs = append(attrs, key) - attrs = append(attrs, value.([]string)...) - } - return attrs -} - -func (s *search) OmitAttrs() []string { - return s.omits -} - func (s *search) Limit(value interface{}) *search { s.limit = s.getInterfaceAsSql(value) return s diff --git a/structs_test.go b/structs_test.go index f52ef906..9a9b23d1 100644 --- a/structs_test.go +++ b/structs_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" "reflect" "time" @@ -194,3 +195,25 @@ func (nt NullTime) Value() (driver.Value, error) { } return nt.Time, nil } + +func getPreparedUser(name string, role string) *User { + var company Company + DB.Where(Company{Name: role}).FirstOrCreate(&company) + + return &User{ + Name: name, + Age: 20, + Role: Role{role}, + BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, + ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, + CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, + Emails: []Email{ + {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, + }, + Company: company, + Languages: []Language{ + {Name: fmt.Sprintf("lang_1_%v", name)}, + {Name: fmt.Sprintf("lang_2_%v", name)}, + }, + } +}