mirror of https://github.com/go-gorm/gorm.git
Refactor self referencing m2m support
This commit is contained in:
parent
8e7d807ebf
commit
44b9911f51
|
@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||||
}
|
}
|
||||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||||
|
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||||
} else {
|
} else {
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
|
|
@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
||||||
|
|
||||||
type SelfReferencingUser struct {
|
type SelfReferencingUser struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"`
|
Name string
|
||||||
|
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||||
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||||
DB.AutoMigrate(&SelfReferencingUser{})
|
DB.AutoMigrate(&SelfReferencingUser{})
|
||||||
|
|
||||||
friend := SelfReferencingUser{}
|
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
||||||
if err := DB.Create(&friend).Error; err != nil {
|
if err := DB.Create(&friend1).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
|
||||||
|
if err := DB.Create(&friend2).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user := SelfReferencingUser{
|
user := SelfReferencingUser{
|
||||||
Friends: []*SelfReferencingUser{&friend},
|
Name: "self_m2m",
|
||||||
|
Friends: []*SelfReferencingUser{&friend1, &friend2},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&user).Error; err != nil {
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Friends").Count() != 2 {
|
||||||
|
t.Errorf("Should find created friends correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
var newUser = SelfReferencingUser{}
|
||||||
|
|
||||||
|
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newUser.Friends) != 2 {
|
||||||
|
t.Errorf("Should preload created frineds for self reference m2m")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
|
||||||
return s.TableName
|
return s.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
|
||||||
values := map[string]interface{}{}
|
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
scope := db.NewScope(source)
|
scope := db.NewScope(source)
|
||||||
modelType := scope.GetModelStruct().ModelType
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
|
||||||
if s.Source.ModelType == modelType {
|
for _, joinTableSource := range joinTableSources {
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
if joinTableSource.ModelType == modelType {
|
||||||
|
for _, foreignKey := range joinTableSource.ForeignKeys {
|
||||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
values[foreignKey.DBName] = field.Field.Interface()
|
conditionMap[foreignKey.DBName] = field.Field.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if s.Destination.ModelType == modelType {
|
break
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
||||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
|
||||||
values[foreignKey.DBName] = field.Field.Interface()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add create relationship in join table for source and destination
|
// Add create relationship in join table for source and destination
|
||||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||||
scope := db.NewScope("")
|
var (
|
||||||
searchMap := map[string]interface{}{}
|
scope = db.NewScope("")
|
||||||
|
conditionMap = map[string]interface{}{}
|
||||||
|
)
|
||||||
|
|
||||||
// getSearchMap() cannot be used here since the source and destination
|
// Update condition map for source
|
||||||
// model types may be identical
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
|
||||||
|
|
||||||
sourceScope := db.NewScope(source)
|
// Update condition map for destination
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
|
||||||
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
|
|
||||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
destinationScope := db.NewScope(destination)
|
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
||||||
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
|
|
||||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var assignColumns, binVars, conditions []string
|
var assignColumns, binVars, conditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
for key, value := range searchMap {
|
for key, value := range conditionMap {
|
||||||
assignColumns = append(assignColumns, scope.Quote(key))
|
assignColumns = append(assignColumns, scope.Quote(key))
|
||||||
binVars = append(binVars, `?`)
|
binVars = append(binVars, `?`)
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
@ -161,9 +146,12 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
scope = db.NewScope(nil)
|
scope = db.NewScope(nil)
|
||||||
conditions []string
|
conditions []string
|
||||||
values []interface{}
|
values []interface{}
|
||||||
|
conditionMap = map[string]interface{}{}
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value := range s.getSearchMap(db, sources...) {
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
|
||||||
|
|
||||||
|
for key, value := range conditionMap {
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
foreignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
|
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
||||||
|
@ -264,6 +266,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||||
relationship.Kind = "many_to_many"
|
relationship.Kind = "many_to_many"
|
||||||
|
|
||||||
|
{ // Foreign Keys for Source
|
||||||
|
joinTableDBNames := []string{}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// if no foreign keys defined with tag
|
// if no foreign keys defined with tag
|
||||||
if len(foreignKeys) == 0 {
|
if len(foreignKeys) == 0 {
|
||||||
for _, field := range modelStruct.PrimaryFields {
|
for _, field := range modelStruct.PrimaryFields {
|
||||||
|
@ -271,15 +280,29 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, foreignKey := range foreignKeys {
|
for idx, foreignKey := range foreignKeys {
|
||||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
// source foreign keys (db names)
|
// source foreign keys (db names)
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
||||||
// join table foreign keys for source
|
|
||||||
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
// setup join table foreign keys for source
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
if len(joinTableDBNames) > idx {
|
||||||
|
// if defined join table's foreign key
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||||
|
} else {
|
||||||
|
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{ // Foreign Keys for Association (Destination)
|
||||||
|
associationJoinTableDBNames := []string{}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// if no association foreign keys defined with tag
|
// if no association foreign keys defined with tag
|
||||||
if len(associationForeignKeys) == 0 {
|
if len(associationForeignKeys) == 0 {
|
||||||
|
@ -288,28 +311,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range associationForeignKeys {
|
for idx, name := range associationForeignKeys {
|
||||||
|
|
||||||
// In order to allow self-referencing many2many tables, the name
|
|
||||||
// may be followed by "=" to allow renaming the column
|
|
||||||
parts := strings.Split(name, "=")
|
|
||||||
name = parts[0]
|
|
||||||
|
|
||||||
if field, ok := toScope.FieldByName(name); ok {
|
if field, ok := toScope.FieldByName(name); ok {
|
||||||
// association foreign keys (db names)
|
// association foreign keys (db names)
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||||
|
|
||||||
// If a new name was provided for the field, use it
|
// setup join table foreign keys for association
|
||||||
name = field.DBName
|
if len(associationJoinTableDBNames) > idx {
|
||||||
if len(parts) > 1 {
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||||
name = parts[1]
|
} else {
|
||||||
}
|
|
||||||
|
|
||||||
// join table foreign keys for association
|
// join table foreign keys for association
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + name
|
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
joinTableHandler := JoinTableHandler{}
|
joinTableHandler := JoinTableHandler{}
|
||||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||||
|
@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
|
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
|
|
Loading…
Reference in New Issue