Allow to omit fields in associations, close #3752

This commit is contained in:
Jinzhu 2020-11-17 21:49:40 +08:00
parent 50df9da6a1
commit 54b80b18bc
2 changed files with 55 additions and 12 deletions

View File

@ -2,6 +2,7 @@ package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
rv = rv.Addr()
}
if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface())
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(f.Interface()).Error)
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface())
}
}
}
@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface())
}
}
@ -292,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) {
if elems.Len() > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface())
}
for i := 0; i < elems.Len(); i++ {
@ -335,3 +330,37 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol
return clause.OnConflict{DoNothing: true}
}
func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error {
var selects, omits []string
refName = refName + "."
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}
if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict)
if len(selects) > 0 {
tx = tx.Select(selects)
}
if len(omits) > 0 {
tx = tx.Omit(omits...)
}
return db.AddError(tx.Create(values).Error)
}

View File

@ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) {
AssertAssociationCount(t, user2, "Account", 0, "after clear")
}
func TestHasOneAssociationWithSelect(t *testing.T) {
var user = *GetUser("hasone", Config{Account: true})
DB.Omit("Account.Number").Create(&user)
AssertAssociationCount(t, user, "Account", 1, "")
var account Account
DB.Model(&user).Association("Account").Find(&account)
if account.Number != "" {
t.Errorf("account's number should not be saved")
}
}
func TestHasOneAssociationForSlice(t *testing.T) {
var users = []User{
*GetUser("slice-hasone-1", Config{Account: true}),