diff --git a/callbacks/associations.go b/callbacks/associations.go index ea90780c..0fa47868 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -67,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -80,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil { + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -142,7 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -162,7 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface()) + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) } } } @@ -221,7 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } } @@ -287,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) { if elems.Len() > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) } for i := 0; i < elems.Len(); i++ { @@ -302,10 +302,14 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { if stmt.DB.FullSaveAssociations { defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) for _, dbName := range s.DBNames { + if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { + continue + } + if !s.LookUpField(dbName).PrimaryKey { defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) } @@ -331,9 +335,12 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol return clause.OnConflict{DoNothing: true} } -func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error { - var selects, omits []string - refName = refName + "." +func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + refName = rel.Name + "." + ) for name, ok := range selectColumns { columnName := ""