mirror of https://github.com/go-gorm/gorm.git
Refactor self referencing m2m support
This commit is contained in:
parent
8e7d807ebf
commit
44b9911f51
|
@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||
|
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||
} else {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
|
|
|
@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
|||
|
||||
type SelfReferencingUser struct {
|
||||
gorm.Model
|
||||
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"`
|
||||
Name string
|
||||
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
|
||||
}
|
||||
|
||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||
DB.AutoMigrate(&SelfReferencingUser{})
|
||||
|
||||
friend := SelfReferencingUser{}
|
||||
if err := DB.Create(&friend).Error; err != nil {
|
||||
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
||||
if err := DB.Create(&friend1).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
|
||||
if err := DB.Create(&friend2).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
user := SelfReferencingUser{
|
||||
Friends: []*SelfReferencingUser{&friend},
|
||||
Name: "self_m2m",
|
||||
Friends: []*SelfReferencingUser{&friend1, &friend2},
|
||||
}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("Friends").Count() != 2 {
|
||||
t.Errorf("Should find created friends correctly")
|
||||
}
|
||||
|
||||
var newUser = SelfReferencingUser{}
|
||||
|
||||
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if len(newUser.Friends) != 2 {
|
||||
t.Errorf("Should preload created frineds for self reference m2m")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
|
|||
return s.TableName
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
values := map[string]interface{}{}
|
||||
|
||||
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
|
||||
for _, source := range sources {
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
for _, joinTableSource := range joinTableSources {
|
||||
if joinTableSource.ModelType == modelType {
|
||||
for _, foreignKey := range joinTableSource.ForeignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
conditionMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else if s.Destination.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Add create relationship in join table for source and destination
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := map[string]interface{}{}
|
||||
var (
|
||||
scope = db.NewScope("")
|
||||
conditionMap = map[string]interface{}{}
|
||||
)
|
||||
|
||||
// getSearchMap() cannot be used here since the source and destination
|
||||
// model types may be identical
|
||||
// Update condition map for source
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
|
||||
|
||||
sourceScope := db.NewScope(source)
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
destinationScope := db.NewScope(destination)
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
// Update condition map for destination
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
for key, value := range searchMap {
|
||||
for key, value := range conditionMap {
|
||||
assignColumns = append(assignColumns, scope.Quote(key))
|
||||
binVars = append(binVars, `?`)
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
|
@ -161,9 +146,12 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
|||
scope = db.NewScope(nil)
|
||||
conditions []string
|
||||
values []interface{}
|
||||
conditionMap = map[string]interface{}{}
|
||||
)
|
||||
|
||||
for key, value := range s.getSearchMap(db, sources...) {
|
||||
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
|
||||
|
||||
for key, value := range conditionMap {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
|
|
@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
)
|
||||
|
||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
||||
foreignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
||||
|
@ -264,6 +266,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||
relationship.Kind = "many_to_many"
|
||||
|
||||
{ // Foreign Keys for Source
|
||||
joinTableDBNames := []string{}
|
||||
|
||||
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
// if no foreign keys defined with tag
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, field := range modelStruct.PrimaryFields {
|
||||
|
@ -271,15 +280,29 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
}
|
||||
|
||||
for _, foreignKey := range foreignKeys {
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||
// source foreign keys (db names)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
||||
// join table foreign keys for source
|
||||
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||
|
||||
// setup join table foreign keys for source
|
||||
if len(joinTableDBNames) > idx {
|
||||
// if defined join table's foreign key
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||
} else {
|
||||
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{ // Foreign Keys for Association (Destination)
|
||||
associationJoinTableDBNames := []string{}
|
||||
|
||||
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
// if no association foreign keys defined with tag
|
||||
if len(associationForeignKeys) == 0 {
|
||||
|
@ -288,28 +311,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
}
|
||||
|
||||
for _, name := range associationForeignKeys {
|
||||
|
||||
// In order to allow self-referencing many2many tables, the name
|
||||
// may be followed by "=" to allow renaming the column
|
||||
parts := strings.Split(name, "=")
|
||||
name = parts[0]
|
||||
|
||||
for idx, name := range associationForeignKeys {
|
||||
if field, ok := toScope.FieldByName(name); ok {
|
||||
// association foreign keys (db names)
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||
|
||||
// If a new name was provided for the field, use it
|
||||
name = field.DBName
|
||||
if len(parts) > 1 {
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
// setup join table foreign keys for association
|
||||
if len(associationJoinTableDBNames) > idx {
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||
} else {
|
||||
// join table foreign keys for association
|
||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + name
|
||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
joinTableHandler := JoinTableHandler{}
|
||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||
|
@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
)
|
||||
|
||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
||||
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||
}
|
||||
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
|
|
Loading…
Reference in New Issue