diff --git a/callbacks/associations.go b/callbacks/associations.go index ce91c2ee..ea90780c 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { + if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { + if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(f.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface()) } } } @@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) } } @@ -292,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) { if elems.Len() > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) } for i := 0; i < elems.Len(); i++ { @@ -335,3 +330,37 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol return clause.OnConflict{DoNothing: true} } + +func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error { + var selects, omits []string + refName = refName + "." + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict) + + if len(selects) > 0 { + tx = tx.Select(selects) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) + } + + return db.AddError(tx.Create(values).Error) +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index f487bd9e..a4fc8c4f 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after clear") } +func TestHasOneAssociationWithSelect(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + DB.Omit("Account.Number").Create(&user) + + AssertAssociationCount(t, user, "Account", 1, "") + + var account Account + DB.Model(&user).Association("Account").Find(&account) + if account.Number != "" { + t.Errorf("account's number should not be saved") + } +} + func TestHasOneAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-hasone-1", Config{Account: true}),