Refactor self referencing m2m support

This commit is contained in:
Jinzhu 2018-02-10 20:57:39 +08:00
parent 8e7d807ebf
commit 44b9911f51
4 changed files with 117 additions and 88 deletions

View File

@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {

View File

@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
type SelfReferencingUser struct {
gorm.Model
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"`
Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
}
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{})
friend := SelfReferencingUser{}
if err := DB.Create(&friend).Error; err != nil {
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
user := SelfReferencingUser{
Friends: []*SelfReferencingUser{&friend},
Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2},
}
if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly")
}
var newUser = SelfReferencingUser{}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m")
}
}

View File

@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
for _, joinTableSource := range joinTableSources {
if joinTableSource.ModelType == modelType {
for _, foreignKey := range joinTableSource.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
conditionMap[foreignKey.DBName] = field.Field.Interface()
}
}
break
}
}
}
return values
}
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("")
searchMap := map[string]interface{}{}
var (
scope = db.NewScope("")
conditionMap = map[string]interface{}{}
)
// getSearchMap() cannot be used here since the source and destination
// model types may be identical
// Update condition map for source
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
sourceScope := db.NewScope(source)
for _, foreignKey := range s.Source.ForeignKeys {
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
destinationScope := db.NewScope(destination)
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
searchMap[foreignKey.DBName] = field.Field.Interface()
}
}
// Update condition map for destination
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
for key, value := range conditionMap {
assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
@ -158,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
scope = db.NewScope(nil)
conditions []string
values []interface{}
conditionMap = map[string]interface{}{}
)
for key, value := range s.getSearchMap(db, sources...) {
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
for key, value := range conditionMap {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}

View File

@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
foreignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
}
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
@ -264,50 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, field.DBName)
{ // Foreign Keys for Source
joinTableDBNames := []string{}
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
joinTableDBNames = strings.Split(foreignKey, ",")
}
}
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// join table foreign keys for source
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
}
}
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}
for _, name := range associationForeignKeys {
// In order to allow self-referencing many2many tables, the name
// may be followed by "=" to allow renaming the column
parts := strings.Split(name, "=")
name = parts[0]
if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// If a new name was provided for the field, use it
name = field.DBName
if len(parts) > 1 {
name = parts[1]
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, field.DBName)
}
}
// join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + name
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// setup join table foreign keys for source
if len(joinTableDBNames) > idx {
// if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else {
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
}
}
}
}
{ // Foreign Keys for Association (Destination)
associationJoinTableDBNames := []string{}
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
associationJoinTableDBNames = strings.Split(foreignKey, ",")
}
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}
for idx, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// setup join table foreign keys for association
if len(associationJoinTableDBNames) > idx {
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
} else {
// join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
}
}
}
}
@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
tagForeignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {