mirror of https://github.com/go-gorm/gorm.git
Allow to omit fields in associations, close #3752
This commit is contained in:
parent
50df9da6a1
commit
54b80b18bc
|
@ -2,6 +2,7 @@ package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
|
if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil {
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
|
@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
|
if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil {
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
|
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface())
|
||||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
|
||||||
).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||||
|
@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
|
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface())
|
||||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
|
||||||
).Create(f.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(
|
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface())
|
||||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
|
||||||
).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
|
saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
|
@ -335,3 +330,37 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol
|
||||||
|
|
||||||
return clause.OnConflict{DoNothing: true}
|
return clause.OnConflict{DoNothing: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error {
|
||||||
|
var selects, omits []string
|
||||||
|
refName = refName + "."
|
||||||
|
|
||||||
|
for name, ok := range selectColumns {
|
||||||
|
columnName := ""
|
||||||
|
if strings.HasPrefix(name, refName) {
|
||||||
|
columnName = strings.TrimPrefix(name, refName)
|
||||||
|
} else if strings.HasPrefix(name, clause.Associations) {
|
||||||
|
columnName = name
|
||||||
|
}
|
||||||
|
|
||||||
|
if columnName != "" {
|
||||||
|
if ok {
|
||||||
|
selects = append(selects, columnName)
|
||||||
|
} else {
|
||||||
|
omits = append(omits, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict)
|
||||||
|
|
||||||
|
if len(selects) > 0 {
|
||||||
|
tx = tx.Select(selects)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(omits) > 0 {
|
||||||
|
tx = tx.Omit(omits...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.AddError(tx.Create(values).Error)
|
||||||
|
}
|
||||||
|
|
|
@ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) {
|
||||||
AssertAssociationCount(t, user2, "Account", 0, "after clear")
|
AssertAssociationCount(t, user2, "Account", 0, "after clear")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasOneAssociationWithSelect(t *testing.T) {
|
||||||
|
var user = *GetUser("hasone", Config{Account: true})
|
||||||
|
|
||||||
|
DB.Omit("Account.Number").Create(&user)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user, "Account", 1, "")
|
||||||
|
|
||||||
|
var account Account
|
||||||
|
DB.Model(&user).Association("Account").Find(&account)
|
||||||
|
if account.Number != "" {
|
||||||
|
t.Errorf("account's number should not be saved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasOneAssociationForSlice(t *testing.T) {
|
func TestHasOneAssociationForSlice(t *testing.T) {
|
||||||
var users = []User{
|
var users = []User{
|
||||||
*GetUser("slice-hasone-1", Config{Account: true}),
|
*GetUser("slice-hasone-1", Config{Account: true}),
|
||||||
|
|
Loading…
Reference in New Issue