diff --git a/association.go b/association.go index d93ff8ca..4c55c7e1 100644 --- a/association.go +++ b/association.go @@ -66,7 +66,9 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation( /*clear*/ true, values...) + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error + } // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -378,11 +380,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + columnName = strings.TrimPrefix(name, association.Relationship.Name) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) + } + } + } + for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } + associationDB := association.DB.Session(&Session{}).Model(nil).Select(selectedSaveColumns).Session(&Session{}) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: @@ -417,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +461,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error } } diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 1ddd3b85..739d1682 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -113,6 +113,11 @@ func TestMany2ManyOmitAssociations(t *testing.T) { if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { t.Errorf("languages count should be %v, but got %v", 2, len(languages)) } + + var newLang = Language{Code: "omitmany2many", Name: "omitmany2many"} + if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { + t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) + } } func TestMany2ManyAssociationForSlice(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod index f6912a0f..67db5117 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 - gorm.io/driver/mysql v1.0.3 - gorm.io/driver/postgres v1.0.6 + gorm.io/driver/mysql v1.0.4 + gorm.io/driver/postgres v1.0.7 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.8 + gorm.io/driver/sqlserver v1.0.6 + gorm.io/gorm v1.20.12 ) replace gorm.io/gorm => ../