Test many to many relation with customized column

This commit is contained in:
Jinzhu 2016-01-04 08:26:02 +08:00
parent caeb4040f2
commit d87a960248
4 changed files with 56 additions and 12 deletions

View File

@ -394,7 +394,7 @@ func toQueryCondition(scope *Scope, columns []string) string {
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
} else {
return strings.Join(columns, ",")
return strings.Join(newColumns, ",")
}
}

View File

@ -63,3 +63,42 @@ func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
t.Errorf("Should not raise error: %s", err)
}
}
type CustomizePerson struct {
IdPerson string `gorm:"column:idPerson;primary_key:true"`
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
}
type CustomizeAccount struct {
IdAccount string `gorm:"column:idAccount;primary_key:true"`
Name string
}
func TestManyToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
person := CustomizePerson{
IdPerson: "person",
Accounts: []CustomizeAccount{account},
}
if err := DB.Create(&account).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&person).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
var person1 CustomizePerson
scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
}
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
t.Errorf("should preload correct accounts")
}
}

View File

@ -92,7 +92,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
assignColumns = append(assignColumns, key)
assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
@ -102,7 +102,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
values = append(values, value)
}
quotedTable := handler.Table(db)
quotedTable := scope.Quote(handler.Table(db))
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
@ -117,11 +117,14 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
}
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var conditions []string
var values []interface{}
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
)
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
@ -129,12 +132,14 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
}
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
quotedTable := handler.Table(db)
var (
scope = db.NewScope(source)
modelType = scope.GetModelStruct().ModelType
quotedTable = scope.Quote(handler.Table(db))
joinConditions []string
values []interface{}
)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
var joinConditions []string
var values []interface{}
if s.Source.ModelType == modelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
for _, foreignKey := range s.Destination.ForeignKeys {

View File

@ -99,7 +99,7 @@ type Relationship struct {
func getForeignField(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) {
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
return field
}
}