Refactor self referencing m2m support

This commit is contained in:
Jinzhu 2018-02-10 20:57:39 +08:00
parent 8e7d807ebf
commit 44b9911f51
4 changed files with 117 additions and 88 deletions

View File

@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
} }
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?) // has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
} else { } else {
var foreignKeyMap = map[string]interface{}{} var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames { for _, foreignKey := range relationship.ForeignDBNames {

View File

@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
type SelfReferencingUser struct { type SelfReferencingUser struct {
gorm.Model gorm.Model
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
} }
func TestSelfReferencingMany2ManyColumn(t *testing.T) { func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{}) DB.AutoMigrate(&SelfReferencingUser{})
friend := SelfReferencingUser{} friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend).Error; err != nil { if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
user := SelfReferencingUser{ user := SelfReferencingUser{
Friends: []*SelfReferencingUser{&friend}, Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2},
} }
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly")
}
var newUser = SelfReferencingUser{}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m")
}
} }

View File

@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
return s.TableName return s.TableName
} }
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
values := map[string]interface{}{}
for _, source := range sources { for _, source := range sources {
scope := db.NewScope(source) scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType { for _, joinTableSource := range joinTableSources {
for _, foreignKey := range s.Source.ForeignKeys { if joinTableSource.ModelType == modelType {
for _, foreignKey := range joinTableSource.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface() conditionMap[foreignKey.DBName] = field.Field.Interface()
} }
} }
} else if s.Destination.ModelType == modelType { break
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
} }
} }
} }
}
return values
} }
// Add create relationship in join table for source and destination // Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("") var (
searchMap := map[string]interface{}{} scope = db.NewScope("")
conditionMap = map[string]interface{}{}
)
// getSearchMap() cannot be used here since the source and destination // Update condition map for source
// model types may be identical s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
sourceScope := db.NewScope(source) // Update condition map for destination
for _, foreignKey := range s.Source.ForeignKeys { s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
destinationScope := db.NewScope(destination)
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
var assignColumns, binVars, conditions []string var assignColumns, binVars, conditions []string
var values []interface{} var values []interface{}
for key, value := range searchMap { for key, value := range conditionMap {
assignColumns = append(assignColumns, scope.Quote(key)) assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`) binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
@ -161,9 +146,12 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
scope = db.NewScope(nil) scope = db.NewScope(nil)
conditions []string conditions []string
values []interface{} values []interface{}
conditionMap = map[string]interface{}{}
) )
for key, value := range s.getSearchMap(db, sources...) { s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
for key, value := range conditionMap {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value) values = append(values, value)
} }

View File

@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
) )
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") foreignKeys = strings.Split(foreignKey, ",")
} }
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") associationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
} }
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
@ -264,6 +266,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many" relationship.Kind = "many_to_many"
{ // Foreign Keys for Source
joinTableDBNames := []string{}
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
joinTableDBNames = strings.Split(foreignKey, ",")
}
// if no foreign keys defined with tag // if no foreign keys defined with tag
if len(foreignKeys) == 0 { if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields { for _, field := range modelStruct.PrimaryFields {
@ -271,15 +280,29 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
for _, foreignKey := range foreignKeys { for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names) // source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// join table foreign keys for source
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName // setup join table foreign keys for source
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) if len(joinTableDBNames) > idx {
// if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else {
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
} }
} }
}
}
{ // Foreign Keys for Association (Destination)
associationJoinTableDBNames := []string{}
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
associationJoinTableDBNames = strings.Split(foreignKey, ",")
}
// if no association foreign keys defined with tag // if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 { if len(associationForeignKeys) == 0 {
@ -288,28 +311,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
for _, name := range associationForeignKeys { for idx, name := range associationForeignKeys {
// In order to allow self-referencing many2many tables, the name
// may be followed by "=" to allow renaming the column
parts := strings.Split(name, "=")
name = parts[0]
if field, ok := toScope.FieldByName(name); ok { if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names) // association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// If a new name was provided for the field, use it // setup join table foreign keys for association
name = field.DBName if len(associationJoinTableDBNames) > idx {
if len(parts) > 1 { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
name = parts[1] } else {
}
// join table foreign keys for association // join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + name joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
} }
} }
}
}
joinTableHandler := JoinTableHandler{} joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, reflectType, elemType) joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
) )
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") tagForeignKeys = strings.Split(foreignKey, ",")
} }
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} }
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {