diff --git a/association.go b/association.go index a382abf2..5cc32e1c 100644 --- a/association.go +++ b/association.go @@ -394,7 +394,7 @@ func toQueryCondition(scope *Scope, columns []string) string { if len(columns) > 1 { return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) } else { - return strings.Join(columns, ",") + return strings.Join(newColumns, ",") } } diff --git a/customize_column_test.go b/customize_column_test.go index cf4f1d1a..0a212525 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -63,3 +63,42 @@ func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { t.Errorf("Should not raise error: %s", err) } } + +type CustomizePerson struct { + IdPerson string `gorm:"column:idPerson;primary_key:true"` + Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` +} + +type CustomizeAccount struct { + IdAccount string `gorm:"column:idAccount;primary_key:true"` + Name string +} + +func TestManyToManyWithCustomizedColumn(t *testing.T) { + DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") + DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) + + account := CustomizeAccount{IdAccount: "account", Name: "id1"} + person := CustomizePerson{ + IdPerson: "person", + Accounts: []CustomizeAccount{account}, + } + + if err := DB.Create(&account).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if err := DB.Create(&person).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + var person1 CustomizePerson + scope := DB.NewScope(nil) + if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { + t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) + } + + if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { + t.Errorf("should preload correct accounts") + } +} diff --git a/join_table_handler.go b/join_table_handler.go index 0a81a929..878bb491 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -92,7 +92,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 var assignColumns, binVars, conditions []string var values []interface{} for key, value := range searchMap { - assignColumns = append(assignColumns, key) + assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) @@ -102,7 +102,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 values = append(values, value) } - quotedTable := handler.Table(db) + quotedTable := scope.Quote(handler.Table(db)) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, @@ -117,11 +117,14 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 } func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var conditions []string - var values []interface{} + var ( + scope = db.NewScope(nil) + conditions []string + values []interface{} + ) for key, value := range s.GetSearchMap(db, sources...) { - conditions = append(conditions, fmt.Sprintf("%v = ?", key)) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } @@ -129,12 +132,14 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour } func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - quotedTable := handler.Table(db) + var ( + scope = db.NewScope(source) + modelType = scope.GetModelStruct().ModelType + quotedTable = scope.Quote(handler.Table(db)) + joinConditions []string + values []interface{} + ) - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - var joinConditions []string - var values []interface{} if s.Source.ModelType == modelType { destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { diff --git a/model_struct.go b/model_struct.go index aeda25f9..d094edad 100644 --- a/model_struct.go +++ b/model_struct.go @@ -99,7 +99,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { return field } }